【神经网络扩展】:断点续训和参数提取
課程來(lái)源:人工智能實(shí)踐:Tensorflow筆記2
文章目錄
- 前言
- 斷點(diǎn)續(xù)訓(xùn)主要步驟
- 參數(shù)提取主要步驟
- 總結(jié)
前言
本講目標(biāo):斷點(diǎn)續(xù)訓(xùn),存取最優(yōu)模型;保存可訓(xùn)練參數(shù)至文本
斷點(diǎn)續(xù)訓(xùn)主要步驟
讀取模型:
先定義出存放模型的路徑和文件名,命名為.ckpt文件。
生成ckpt文件的時(shí)候會(huì)同步生成索引表,所以通過(guò)判斷是否存在索引表來(lái)知曉是不是已經(jīng)保存過(guò)模型參數(shù)。
如果有了索引表就利用load_weights函數(shù)讀取已經(jīng)保存的模型參數(shù)。
code:
checkpoint_save_path = "./checkpoint/fashion.ckpt" if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)保存模型:
保存模型參數(shù)可以使用TensorFlow給出的回調(diào)函數(shù),直接保存訓(xùn)練出來(lái)的模型參數(shù)
tf.keras.callbacks.ModelCheckpoint( filepath=路徑文件名(文件存儲(chǔ)路徑),
save_weights_only=True/False,(是否只保留參數(shù)模型)
save_best_only=True/False(是否只保留最優(yōu)結(jié)果)) 執(zhí)行訓(xùn)練過(guò)程中時(shí),加入callbacks選項(xiàng):
history=model.fit(callbacks=[cp_callback])
code:
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])第一次運(yùn)行:
第二次運(yùn)行:可以發(fā)現(xiàn)模型并不是從初始訓(xùn)練,而是在基于保存的模型開(kāi)始訓(xùn)練的(這一點(diǎn)可以從準(zhǔn)確率和損失看出):
全部代碼:
參數(shù)提取主要步驟
設(shè)置打印的格式,使所有參數(shù)都打印出來(lái)
np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)
將所有可訓(xùn)練參數(shù)存入文本:
file = open('./weights.txt', 'w') for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n') file.close()完整代碼:
import tensorflow as tf import os import numpy as npnp.set_printoptions(threshold=np.inf)fashion = tf.keras.datasets.fashion_mnist (x_train, y_train), (x_test, y_test) = fashion.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax') ])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt" if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback]) model.summary()print(model.trainable_variables) file = open('./weights.txt', 'w') for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n') file.close()效果:
總結(jié)
課程鏈接:MOOC人工智能實(shí)踐:TensorFlow筆記2
總結(jié)
以上是生活随笔為你收集整理的【神经网络扩展】:断点续训和参数提取的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 地下城的暗帝什么时候出觉醒
- 下一篇: 为什么两层3*3卷积核效果比1层5*5卷