TensorFlow学习之——checkpoints
在看別人的訓練網絡中一開頭就遇到這樣一行代碼:
ckpt = tf.train.get_checkpoint_state(directories.checkpoints)鼠標放在函數名上,ctrl+B,或者ctrl+點擊函數名,可以跳轉到函數的定義,可以知道tf.train.get_checkpoint_state函數通過目錄下的checkpoint文件找到checkpoint狀態proto。
訓練可能分成多次迭代,在迭代期間或者訓練完成測試之前,需要將訓練得到的參數保存到一個文件中,等到需要時再從文件中讀取。TensorFlow提供了兩種模型格式:
- checkpoints:這種格式依賴于創建模型的代碼。
- SavedModel:這種格式與創建模型的代碼無關。
Checkpoints文件是這樣的一個二進制文件,好比是一個中轉站,Tensorflow針對這一需提供了Saver類把變量名映射到對應的tensor值,并可以從checkpoints文件中恢復變量。
再回到第一行代碼,返回得到的ckpt其中有model_checkpoint_path和all_model_checkpoint_paths兩個屬性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths則有未被刪除的所有tensorflow模型文件的文件名。
既然有預訓練的模型,就應該把checkpoint文件放入文件夾下。checkpoint文件其實有三個文件組成,后綴名分別是.meta和.index和.data-00000-of-00001文件。
當需要恢復某個模型的參數,繼續進行訓練時,可以使用下面的代碼(不需要加后綴,就可以同時包含三個文件),恢復訓練時的最后一個模型參數:
if args.restore_last and ckpt.model_checkpoint_path: #.model_checkpoint_path保存了最新的tensorflow模型文件的文件名# Continue training saved model 繼續訓練已經保存的模型,側面也表明之前有預訓練的模型#saver.restore(sess, ckpt.model_checkpoint_path) #恢復模型參數,繼續訓練saver.restore(sess,'checkpoints/noiseMScsC8_epoch15.ckpt-15') # 恢復模型參數,繼續訓練.預訓練了15次,config中默認512次#https://www.cnblogs.com/darkknightzh/p/7198773.htmlprint('{} restored.'.format(ckpt.model_checkpoint_path))?
總結
以上是生活随笔為你收集整理的TensorFlow学习之——checkpoints的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: hive函数大全:11大类、109个函数
- 下一篇: 【每日SQL打卡】