TensorFlow模型持久化
模型持久化的目的在于可以使模型訓練后的結果重復使用,節省重復訓練模型的時間。
模型保存
train.Saver類是TensorFlow提供的用于保存和還原模型的API,使用非常簡單。
import tensorflow as tf# 聲明兩個變量并計算其加和 a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name='a') b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name='b') result = a + b# 初始化全部變量的操作 init_op = tf.global_variables_initializer() # 定義 Saver 類對象用于保存模型 saver = tf.train.Saver()with tf.Session() as sess:sess.run(init_op)saver.save(sess, "./model/model.ckpt")上面的代碼實現了一個簡單的TensorFlow模型持久化的功能。
save()函數的sess參數用于指定要保存的模型會話,save_path參數用于指定路徑。
通過Saver類的save()函數將TensorFlow模型保存到一個指定路徑下的model.ckpt文件中。
(TensorFlow模型一般會保存在文件名為.ckpt的文件中,可以省略后綴名,但是好的編程習慣是對其加以指定)
雖然上面的程序只制定了一個文件路徑,但是在這個文件目錄下回出現4個文件:
- checkpoint文件是一個文本文件,保存了一個目錄下所有的模型文件列表。該文件會被自動更新,當有更多模型被保存到model目錄下時,文件內容會更新為最新的訓練模型。
- model.ckpt.data-00000-of-00001文件是一個二進制文件,保存了TensorFlow中每一個變量的取值。
- model.ckpt.index文件是一個二進制文件,保存了每一個變量的名稱,是一個string-string的table,其中table的key值為tensor名,value值為BundleEntryProto。
- model.ckpt.meta文件是一個二進制文件,保存了計算圖的結構。
將一個模型文件分成多個文件保存的原因是TensorFlow會將模型的計算圖結構以及參數的取值分開來保存。
模型加載
TensorFlow也提供了相應的函數來加載保存的模型。
with tf.Session() as sess:saver.restore(sess, "./model/model.ckpt")print(sess.run(result))輸出:
加載模型的代碼和保存模型的代碼相似,但是省略了初始化全部變量的過程。
使用restore()函數需要在模型參數恢復前定義計算圖上的所有運算,并且變量名需要與模型存在的變量名保持一致,這樣就可以將變量的值通過已保存的模型加載進來。
有時我們可能不希望重復定義計算圖上的計算,太繁瑣了,TensorFlow提供了import_meta_graph()函數加載模型的計算圖。
import_meta_graph()函數的輸入參數為.meta文件的路徑,返回一個Saver類實例,再調用這個實例的restore()函數就可以恢復參數了。
saver = tf.train.import_meta_graph("./model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess, "./model/model.ckpt")# 獲取默認計算圖上指定節點處的張量print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))輸出:
.ckpt.meta文件保存了計算圖的結構,通過import_meta_graph()函數將計算圖導入到程序中并傳遞給saver,之后在會話中通過restore()函數對該計算圖中變量的值進行加載。
get_tensor_by_name()函數用于獲取指定節點處的張量(add:0 表示add節點的第一個輸出)。
總結
以上是生活随笔為你收集整理的TensorFlow模型持久化的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 将文件从HDFS复制到本地
- 下一篇: Python执行脚本文件将输出既能显示控