keras优化算法_目标检测算法 - CenterNet - 代码分析
吃水不忘打井人,分析github上的基于keras的實(shí)現(xiàn):
xuannianz/keras-CenterNet?github.com代碼主體結(jié)構(gòu)模型訓(xùn)練的主函數(shù)流程如下所示,該流程也是使用keras的較為標(biāo)準(zhǔn)的流程。其中代碼篇幅較大的是數(shù)據(jù)準(zhǔn)備的部分,通常的代碼也亦如此。下面按照不同的部分分別進(jìn)行說(shuō)明。
create_generators 數(shù)據(jù)集準(zhǔn)備該代碼支持Pascal VOC格式、COCO格式以及CSV格式。keras中有三個(gè)函數(shù)可以用來(lái)進(jìn)行模型的訓(xùn)練:分別是fit,fit_generator和train_on_batch。
fit(train_x, train_y, batchsize, epochs)在使用fit進(jìn)行模型訓(xùn)練時(shí),通常假設(shè)整個(gè)訓(xùn)練集都可以放入RAM,并且沒(méi)有數(shù)據(jù)增強(qiáng)(即不需要keras生成器)。常用于簡(jiǎn)單小型的數(shù)據(jù)集訓(xùn)練。
fit_generator;常常使用的模型訓(xùn)練函數(shù)fit_generator適用于大數(shù)據(jù)集無(wú)法直接全部放入內(nèi)存中,以及標(biāo)注數(shù)據(jù)較少需要使用數(shù)據(jù)增強(qiáng)來(lái)增加訓(xùn)練模型的泛化能力。fit_generator需要傳入一個(gè)數(shù)據(jù)生成器,數(shù)據(jù)生成器可以每次動(dòng)態(tài)的生成一個(gè)batchsize的訓(xùn)練數(shù)據(jù),通常我們也將數(shù)據(jù)增強(qiáng)放入數(shù)據(jù)生成器中,這樣便可以動(dòng)態(tài)的生成增強(qiáng)后的數(shù)據(jù)。在使用fit_generator時(shí),需要傳入steps_per_epoch的值,而fit函數(shù)則不需要,這是因?yàn)閒it函數(shù)的steps_per_epoch默認(rèn)等于總的訓(xùn)練數(shù)據(jù)/batchsize,而對(duì)于fit_generator來(lái)說(shuō),如果采用了數(shù)據(jù)增強(qiáng),則可以產(chǎn)生無(wú)限的batchsize訓(xùn)練數(shù)據(jù),因此需要指定該參數(shù)。
By the way,數(shù)據(jù)生成器可以使用keras的API或者直接自己手碼python的代碼,因?yàn)槠浔举|(zhì)上也就是python的函數(shù)。
train_on_batch(batchX, batchY)train_on_batch用于需要對(duì)訓(xùn)練迭代進(jìn)行精細(xì)控制,給其傳入一批數(shù)據(jù)即可(數(shù)據(jù)大小任意),不需要提供batchsize的大小。通常很少使用該函數(shù)進(jìn)行模型訓(xùn)練。
- 本算法的實(shí)現(xiàn)過(guò)程就是采用的fit_generator進(jìn)行的模型訓(xùn)練。因此需要為其構(gòu)建數(shù)據(jù)生成器。common.py文件:class Generator(keras.utils.Sequence)構(gòu)建數(shù)據(jù)生成器的基類,咱們先說(shuō)道說(shuō)道keras.utils.Sequence這個(gè)類。
Generator類可以當(dāng)成一個(gè)抽象基類,其中主要實(shí)現(xiàn)的是batch的劃分、數(shù)據(jù)增強(qiáng)的處理、以及標(biāo)注數(shù)據(jù)的轉(zhuǎn)換(將bounding box的標(biāo)注形式轉(zhuǎn)換成高斯分布的標(biāo)注)。而真正使用的數(shù)據(jù)集的生成器如下所示。主要按照不同的數(shù)據(jù)集生成的類,并均都繼承于Generator抽象類,這里區(qū)分不同的數(shù)據(jù)集主要為了能方便區(qū)分其不同的數(shù)據(jù)標(biāo)注格式,使用起來(lái)更為方便。主要是load_annotations()和load_image()函數(shù)的實(shí)現(xiàn)。至此數(shù)據(jù)生成器便構(gòu)建完成了。
class PascalVocGenerator(Generator) class CocoGenerator(Generator)centernet網(wǎng)絡(luò)構(gòu)建算法實(shí)現(xiàn)采用的Resnet50作為網(wǎng)絡(luò)的backbone,采用下述引用網(wǎng)絡(luò)。網(wǎng)絡(luò)構(gòu)建這里相對(duì)就比較簡(jiǎn)單了,取出Resnet的C5,先添加了一層dropout,然后進(jìn)行了上采樣,然后分別構(gòu)建網(wǎng)絡(luò)head,主要有三支:中心點(diǎn)預(yù)測(cè)、中心點(diǎn)偏移值預(yù)測(cè)以及bouding box的size預(yù)測(cè)。
from keras.applications.resnet50 import ResNet50最后構(gòu)建model,使用keras的Lambda層構(gòu)建loss,作為model的output
loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input]) model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_])預(yù)訓(xùn)練模型權(quán)重加載keras的模型加載可以使用load_weights來(lái)實(shí)現(xiàn),其模型加載可以按照模型結(jié)構(gòu)加載,此時(shí)by_name需設(shè)置為False。否則將按照網(wǎng)絡(luò)層的名字來(lái)加載,此時(shí)通常將skip_mismatch也設(shè)置成True,即僅加載名字相同的層,其他名字不同的層直接跳過(guò)。因此可以利用這個(gè)特性,對(duì)已訓(xùn)練好的網(wǎng)絡(luò)局部進(jìn)行修改,然后再加載之前訓(xùn)練好的模型,方便進(jìn)行模型的調(diào)優(yōu)。
model.load_weights(args.snapshot, by_name=True, skip_mismatch=True)模型配置其中l(wèi)oss參數(shù)的傳遞有幾種形式。
- 目標(biāo)函數(shù)/損失函數(shù)的字符串,比如keras內(nèi)置的一些損失函數(shù)
- 目標(biāo)函數(shù)/損失函數(shù),通常為自定義的損失函數(shù)
- 將目標(biāo)函數(shù)/損失函數(shù)定義成model的一個(gè)層,類似本代碼的實(shí)現(xiàn)。本代碼實(shí)現(xiàn)時(shí),因?yàn)橹苯影裭oss作為model的輸出,因此輸入y_true和y_pred,實(shí)際使用y_pred即輸出loss,對(duì)其進(jìn)行優(yōu)化。
總結(jié)
以上是生活随笔為你收集整理的keras优化算法_目标检测算法 - CenterNet - 代码分析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: php5.4源码下载,WordPress
- 下一篇: java环境变量一闪而过_Java环境变