TensorFlow(1)-模型相关基础概念
TensorFlow-1
- 1.Graph對象
- 2.Session對象
- 3.Variabels變量
- 4. placeholders與feed_dict
- 5. tf.train.Saver() 模型參數保存、加載
Tensorflow 中文官網教程–2.0版本的官方教程
TensorFlow教程:TensorFlow快速入門教程(非常詳細)
pytorch Vs tensorflow
Tensorflow來源: 由較低級別的符號計算庫(例如: Theano) 和 較高級別的網絡規范庫(例如:Blocks 和 Lasagne)組合而成。
TensorFlow的劣勢: 1.x API混亂冗余;2.x重點關注tf.keras,棄用其他API。但是1.x 和2.x 的兼容性堪憂。
TensorFlow的計算方式: 創建數據流圖,將數據放入數據流圖中計算。數據流圖中的節點表示數學操作(op:operation),連線表示tensor 流動的通道。每個節點獲得若干個tensor,執行計算后產生若干個tensor。數據流圖在會話(Session)中啟動運行。
import tensorflow as tf1.Graph對象
計算圖–可以認為是詳細的流程圖,其包括每一步的操作[op]和變量名字。
顯式構建,graph構造函數無需接受任何參數
隱式構建,當Tensorflow庫被加載時,它會自動的創建一個Graph對象,并將 其作為默認的數據流圖。獲取默認數據流圖具柄:
default_graph = tf.get_default_graph2.Session對象
session負責分配資源,計算operation, 得出結果。 Session構造函數可接受3個參數。
使用run 來運行相應的計算操作,得到fetches中的張量值。tf.Session.run()函數返回值為fetches指定的執行結果。如果fetches是一個元素就返回一個值;若fetches是一個list,則返回list的值,若fetches是一個字典類型,則返回和fetches同keys的字典。
session.run( fetches, feed_dict=None, options=None, run_metadata=None)
Session的開啟涉及具體運算,比較消耗資源。在使用結束后,建議關閉Session。
-> 手動關閉關閉session
->自動關閉session 使用with來限定session 的作用范圍
a = tf.constant(1, dtype=tf.int8) b = tf.constant(2, dtype=tf.int8) res= a + b with tf.Session() as sess: # 運算結束后session自動關閉sess.run(res) # 執行運算res.eval() # tensor eval() 方法和 sess.run(res)效果一致參考文檔:TensorFlow中Session的使用
3.Variabels變量
Variabels 類型的參數可通過梯度下降更新、訓練。必須明確的初始化而且可以通過Saver保存到磁盤上。
定義Variabels 類型的參數: 使用tf.Variable(tensor)封裝相應的tensor
在graph 中若含有variabels 類型的變量,必須使用在session中顯式調用 初始化函數。
-> 全量variable 初始化 tf.global_variables_initializer()
-> 部分variable 初始化 tf.variables_initializer()
var1 = tf.Variable(0,name="initialize_me") var2 = tf.Variable(1,name="no_initialization") init = tf.variables_initializer([var1],name="init_var1") with tf.Session() as sess:sess.run(init)4. placeholders與feed_dict
tf.placeholder 創建占位符號,比如模型的輸入數據,其只有在訓練與預測時才會有值。賦值時,使用feed_dict 進行賦值操作。
import numpy as np x = tf.placeholder(tf.float32, shape=(10, 10)) # 10行10列 y = tf.matmul(x, x) # 矩陣乘 with tf.Session() as sess: #print(sess.run(y)) # ERROR: 此處x還沒有賦值. rand_array = np.random.rand(1024, 1024) print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.5. tf.train.Saver() 模型參數保存、加載
v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") init = tf.initialize_all_variables() # Add ops to save and restore all the variables. saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, save the variables to disk. with tf.Session() as sess:sess.run(init)save_path = saver.save(sess, "/tmp/model.ckpt")print "Model saved in file: ", save_path模型恢復時,variable 不需要初始化
v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and # do some work with the model. with tf.Session() as sess:saver.restore(sess, "/tmp/model.ckpt")print "Model restored."總結
以上是生活随笔為你收集整理的TensorFlow(1)-模型相关基础概念的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: windows下关于Objective-
- 下一篇: Redis集群添加节点