Keras的回调函数
訓(xùn)練模型時,很多事情一開始無法預(yù)測。尤其是你不知道需要多少輪次才能得到最佳驗證損失。通常簡單的辦法是:訓(xùn)練足夠多的輪次,這時模型已經(jīng)開始過擬合了,根據(jù)第一次運行來確定訓(xùn)練所需要的正確輪次,然后使用這個最佳輪數(shù)從頭開始啟動一個新的訓(xùn)練。當然,這種方法很浪費。更好的辦法是使用回調(diào)函數(shù)。
ModelCheckpoing和EarlyStopping回調(diào)函數(shù)
如果監(jiān)控的目標在設(shè)定輪數(shù)內(nèi)不再改善,可以用EarlyStopping回調(diào)函數(shù)來中斷訓(xùn)練。這個回調(diào)函數(shù)通常與ModelCheckpoint結(jié)合使用,后者可以在訓(xùn)練過程中持續(xù)的不斷保存模型(你也可以選擇只保存目前的最佳模型,即一輪結(jié)束后具有最佳性能的模型)
import keras # 通過fit的callbacks參數(shù)將回調(diào)函數(shù)傳入模型中,這個參數(shù)接收一個回調(diào)函數(shù)列表,你可以傳入任意個回調(diào)函數(shù) callback_lists = [keras.callbacks.EarlyStopping(monitor = 'acc', # 監(jiān)控模型的驗證精度patience = 1,), # 如果精度在多于一輪的時間(即兩輪)內(nèi)不再改善,就中斷訓(xùn)練# ModelCheckpoint用于在每輪過后保存當前權(quán)重keras.callbacks.ModelCheckpoint(filepath = 'my_model.h5', # 目標文件的保存路徑# 這兩個參數(shù)的意思是,如果val_loss沒有改善,那么不需要覆蓋模型文件,# 這就可以始終保存在訓(xùn)練過程中見到的最佳模型monitor = 'val_loss', save_best_only = True,) ]model.compile(optimizer = 'rmsprop', loss = 'binary_crossentropy', metrics = ['acc'])# 由于回調(diào)函數(shù)要監(jiān)控驗證損失和驗證精度,所以在調(diào)用fit時需要傳入validation_data(驗證數(shù)據(jù)) model.fit(x, y, epochs = 10,batch_size = 32,callbacks = callbacks_list,validation_data = (x_val, y_val))ReduceLROnPlateau回調(diào)函數(shù)
如果驗證損失不再改善,你可以使用這個回調(diào)函數(shù)來降低學(xué)習(xí)率。在訓(xùn)練過程中如果出現(xiàn)了損失平臺(loss plateau),那么增大或減小學(xué)習(xí)率都是跳出局部最小值的有效策略。這里可以用到ReduceLROnPlateau回調(diào)函數(shù)。
callback_list = [keras.callbacks.ReduceLROnPlateau(monitor = 'val_loss', # 監(jiān)控模型的驗證損失factor = 0.1, # 觸發(fā)時將學(xué)習(xí)率除以10patience = 10) # 如果驗證損失在10輪內(nèi)都沒有改善,那么就觸發(fā)這個回調(diào)函數(shù) ]# 由于回調(diào)函數(shù)要監(jiān)控驗證損失和驗證精度,所以在調(diào)用fit時需要傳入validation_data(驗證數(shù)據(jù)) model.fit(x, y, epochs = 10,batch_size = 32,callbacks = callbacks_list,validation_data = (x_val, y_val))編寫自己的回調(diào)函數(shù)
如果你需要在訓(xùn)練過程中采取特定行動,而這項行動又沒有包含在內(nèi)置的回調(diào)函數(shù)中,那么可以編寫自己的回調(diào)函數(shù)。這種回調(diào)函數(shù)的實現(xiàn)形式是創(chuàng)建keras.callbacks.Callback類的子類。然后你可以實現(xiàn)下面這些方法,它們分別在訓(xùn)練過程中的不同時間被調(diào)用。
on_epoch_begin # 在每輪開始時被調(diào)用 on_epoch_end # 在每輪結(jié)束時被調(diào)用on_batch_begin # 在處理每個批量之前被調(diào)用 on_batch_end # 在處理每個批量之后被調(diào)用on_train_begin # 在訓(xùn)練開始時被調(diào)用 on_train_end # 在訓(xùn)練結(jié)束時被調(diào)用這些方法被調(diào)用時都有一個logs參數(shù),這個參數(shù)是一個字典,里面包好前一個批量、前一個輪次或前一次訓(xùn)練的信息(訓(xùn)練的指標和驗證指標等)。此外,回調(diào)函數(shù)還可以訪問下列屬性:
- self.model:調(diào)用回調(diào)函數(shù)的模型實例
- self.validation_data:傳入fit作為驗證數(shù)據(jù)的值
下面是一個回調(diào)函數(shù)的簡單示例,它可以在每輪結(jié)束后將模型每層激活保存到硬盤(格式為Numpy數(shù)組),這個激活是對驗證集的第一個樣本計算得到的。
import keras import numpy as npclass ActivationLogger(keras.callbacks.Callback):def set_model(self, model):self.model = model # 在訓(xùn)練之前由父模型調(diào)用,告訴回調(diào)函數(shù)是哪個模型在調(diào)用它layer_outputs = [layer.output for layer in model.layers]self.activations_model = keras.models.Model(model.input,layer_outputs) # 模型實例,返回每層的激活def on_epoch_end(self, epoch, logs = None):if self.validation_data is None:raise RuntimeError('Requires validation_data')validation_sample = self.validation_data[0][0:1] # 獲取驗證數(shù)據(jù)的第一個輸入樣本activation = self.activations_model.predict(validation_sample)f = open('activation_at_epoch_' + str(epoch) + '.npz', 'w') # 將數(shù)據(jù)保存到硬盤np.savez(f, activations)f.close() 《新程序員》:云原生和全面數(shù)字化實踐50位技術(shù)專家共同創(chuàng)作,文字、視頻、音頻交互閱讀總結(jié)
以上是生活随笔為你收集整理的Keras的回调函数的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Keras函数式API
- 下一篇: 数的四则运算