tensorflow 模型预训练后的参数restore finetuning
之前訓練的網絡中有一部分可以用到一個新的網絡中,但是不知道存儲的參數如何部分恢復到新的網絡中,也了解到有許多網絡是通過利用一些現有的網絡結構,通過finetuning進行改造實現的,因此了解了一下關于模型預訓練后部分參數restore和finetuning的內容
更多內容參見:
https://blog.csdn.net/mieleizhi0522/article/details/80535189
https://blog.csdn.net/leo_xu06/article/details/79200634
https://blog.csdn.net/b876144622/article/details/79962727
https://blog.csdn.net/ying86615791/article/details/76215363
首先了解一下變量(tf.Variable),變量是tf框架中用于存儲參數的對象,我們這里要恢復的參數也是variable類型的。訓練的參數是放在不同名字下的variable中的,checkpoint中存儲的變量也是通過不同的名字進行區分的,這里如果要恢復指定的參數可以使用
with tf.variable_scope('', reuse = True):sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))Saver是用于保存變量的對象。下面是saver對象的創建和調用
saver = tf.train.Saver() save_path = saver.save(sess, "/tmp/model.ckpt")?如果僅在session開始時恢復模型變量的一個子集,需要對剩下的變量執行初始化op。
# Create some variables. v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Add ops to save and restore only 'v2' using the name "my_v2" saver = tf.train.Saver({"my_v2": v2})對已有checkpoint內容進行查看,可以使用一下代碼(來自https://blog.csdn.net/mieleizhi0522/article/details/80535189),然后就可以結合之前的指定變量名的方法對參數進行restore了。注意,在完成部分參數的restore后要記得對沒有初始化的變量進行初始化,否則報錯。
import tensorflow as tfimport osfrom tensorflow.python import pywrap_tensorflowmodel_dir=r'G:\KeTi\C3D'checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")# 從checkpoint中讀出數據reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法var_to_shape_map = reader.get_variable_to_shape_map()# 輸出權重tensor名字和值for key in var_to_shape_map:print("tensor_name: ", key,reader.get_tensor(key).shape)輸出
tensor_name: var_name/wc4a (3, 3, 3, 256, 512)tensor_name: var_name/wc3a (3, 3, 3, 128, 256)tensor_name: var_name/wd1 (8192, 4096)tensor_name: var_name/wc5b (3, 3, 3, 512, 512)tensor_name: var_name/bd1 (4096,)tensor_name: var_name/wd2 (4096, 4096)tensor_name: var_name/wout (4096, 101)tensor_name: var_name/wc1 (3, 3, 3, 3, 64)tensor_name: var_name/bc4b (512,)tensor_name: var_name/wc2 (3, 3, 3, 64, 128)tensor_name: var_name/bc3a (256,)tensor_name: var_name/bd2 (4096,)tensor_name: var_name/bc5a (512,)tensor_name: var_name/bc2 (128,)tensor_name: var_name/bc5b (512,)tensor_name: var_name/bout (101,)tensor_name: var_name/bc4a (512,)tensor_name: var_name/bc3b (256,)tensor_name: var_name/wc4b (3, 3, 3, 512, 512)tensor_name: var_name/bc1 (64,)tensor_name: var_name/wc3b (3, 3, 3, 256, 256)tensor_name: var_name/wc5a (3, 3, 3, 512, 512)總結
以上是生活随笔為你收集整理的tensorflow 模型预训练后的参数restore finetuning的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 参考使用CSDN-markdown编辑器
- 下一篇: Win10+GTX 1080Ti+Ana