【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型
本文主要講解在現(xiàn)有常用模型基礎(chǔ)上,如何微調(diào)模型,減少訓(xùn)練時(shí)間,同時(shí)保持模型檢測(cè)精度。
首先介紹下Slim這個(gè)Google公布的圖像分類(lèi)工具包,可在github鏈接:modules and examples built with tensorflow 中找到slim包。
上面這個(gè)鏈接目錄下主要包含:
official models(這個(gè)是用Tensorflow高層API做的例子模型集,建議初學(xué)者可嘗試);
research models(這個(gè)是很多研究者利用tensorflow做的模型集,這個(gè)不是官方提供的,是研究者個(gè)人在維護(hù)的);
samples folder (包含代碼片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代碼呈現(xiàn));
而我說(shuō)的slim工具包就在research文件夾下。
Slim庫(kù)結(jié)構(gòu)
不僅定義了很多接口,還提供了很多ImageNet數(shù)據(jù)集上常用的網(wǎng)絡(luò)結(jié)構(gòu)和預(yù)訓(xùn)練模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。
?
下面用slim工具包中的文件來(lái)對(duì)自己的數(shù)據(jù)集做訓(xùn)練,訓(xùn)練可分為利用已有的模型架構(gòu)(如常見(jiàn)的VGG,Inception等的卷積,池化這些結(jié)構(gòu))來(lái)全新訓(xùn)練權(quán)重文件或是微調(diào)權(quán)重文件。由于很多已有的imagenet圖像數(shù)據(jù)覆蓋面已經(jīng)很廣,基于此訓(xùn)練的網(wǎng)絡(luò)權(quán)重已經(jīng)能提取大致的目標(biāo)特征(從低微像素到高維的結(jié)構(gòu)特征),所以可使用fine-tune只訓(xùn)練框架中某些層的權(quán)重,當(dāng)然根據(jù)自己數(shù)據(jù)集做全部權(quán)重重新訓(xùn)練的檢測(cè)效果理論會(huì)更好些,需要權(quán)衡時(shí)間成本和檢測(cè)精度的需求了;
下面會(huì)依據(jù)成熟網(wǎng)絡(luò)結(jié)構(gòu)Incvption V3分別做權(quán)重文件的全部重新訓(xùn)練和部分重新訓(xùn)練(即fine-tune)來(lái)介紹;
(前提是你將slim工具庫(kù)下載下來(lái),安裝了必要的tensorflow等框架;并且根據(jù)訓(xùn)練圖像制作完成了tfrecord文件)
有關(guān)tfrecord訓(xùn)練文件的制作請(qǐng)參考:將圖像制作成tfrecord
step1:定義新的datasets數(shù)據(jù)集文件
在slim/datasets/文件夾下 添加一個(gè)python文件,直接復(fù)制一份flowers.py,重命名為“satellite.py”(這個(gè)名字可根據(jù)你實(shí)際的數(shù)據(jù)集名字來(lái)更改,我用的是何大神的航拍圖數(shù)據(jù)集)
需要對(duì)賦值生成后的satellite.py內(nèi)容做如下修改:
_FILE_PATTERN = 'flowers_%s_*.tfrecord'?
更改為
_FILE_PATTERN = 'satellite_%s_*.tfrecord' ? ?? #這個(gè)主要是根據(jù)你之前制作的tfrecord文件名來(lái)改的,我制作的訓(xùn)練文件為satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,驗(yàn)證文件為satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord。
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
更改為
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200} ?#這個(gè)根據(jù)自己訓(xùn)練和驗(yàn)證樣本數(shù)量來(lái)改,我的訓(xùn)練數(shù)據(jù)是800張圖/類(lèi),共6類(lèi),驗(yàn)證集時(shí)200張/類(lèi),共6類(lèi);
_NUM_CLASSES = 5
更改為
_NUM_CLASSES = 6 ? ? ? #實(shí)際訓(xùn)練類(lèi)別為6類(lèi);
?
還需要對(duì)satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),這行代碼做更改,由于用的數(shù)據(jù)集源文件都是XXXX.jpg格式,因此將默認(rèn)的圖像格式轉(zhuǎn)為jpg,更改后為'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,對(duì)satellite.py文件完成制作與更改(其源碼如下):
satellite.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Provides data for the flowers dataset.The dataset scripts used to create the dataset can be found at: tensorflow/models/slim/datasets/download_and_convert_flowers.py """from __future__ import absolute_import from __future__ import division from __future__ import print_functionimport os import tensorflow as tffrom datasets import dataset_utilsslim = tf.contrib.slim_FILE_PATTERN = 'satellite_%s_*.tfrecord'SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}_NUM_CLASSES = 6_ITEMS_TO_DESCRIPTIONS = {'image': 'A color image of varying size.','label': 'A single integer between 0 and 4', }def get_split(split_name, dataset_dir, file_pattern=None, reader=None):"""Gets a dataset tuple with instructions for reading flowers.Args:split_name: A train/validation split name.dataset_dir: The base directory of the dataset sources.file_pattern: The file pattern to use when matching the dataset sources.It is assumed that the pattern contains a '%s' string so that the splitname can be inserted.reader: The TensorFlow reader type.Returns:A `Dataset` namedtuple.Raises:ValueError: if `split_name` is not a valid train/validation split."""if split_name not in SPLITS_TO_SIZES:raise ValueError('split name %s was not recognized.' % split_name)if not file_pattern:file_pattern = _FILE_PATTERNfile_pattern = os.path.join(dataset_dir, file_pattern % split_name)# Allowing None in the signature so that dataset_factory can use the default.if reader is None:reader = tf.TFRecordReaderkeys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),}items_to_handlers = {'image': slim.tfexample_decoder.Image(),'label': slim.tfexample_decoder.Tensor('image/class/label'),}decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)labels_to_names = Noneif dataset_utils.has_labels(dataset_dir):labels_to_names = dataset_utils.read_label_file(dataset_dir)return slim.dataset.Dataset(data_sources=file_pattern,reader=reader,decoder=decoder,num_samples=SPLITS_TO_SIZES[split_name],items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,num_classes=_NUM_CLASSES,labels_to_names=labels_to_names)step2:注冊(cè)數(shù)據(jù)庫(kù)
接下來(lái)對(duì)slim/datasets/dataset_factory.py文件做更改,注冊(cè)下satellite數(shù)據(jù)庫(kù);修改之處如下(添加了兩行紅色字體代碼):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite
datasets_map = {
? ? 'cifar10': cifar10,
? ? 'flowers': flowers,
? ? 'imagenet': imagenet,
? ? 'mnist': mnist,
?? ?'satellite': satellite,
?? ?
}
step3:準(zhǔn)備訓(xùn)練文件夾
在slim文件夾下新建如下目錄文件夾,并將對(duì)應(yīng)的文件放在相應(yīng)目錄下
slim/
? ? satellite/
? ? ? ? ? ? ? data/
? ? ? ? ? ? ? ? ? ?satellite_train_00000-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_train_00001-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_validation_00000-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_validation_00001-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?label.txt
? ? ? ? ? ? ? pretrained/
? ? ? ? ? ? ? ? ? ?inception_v3.ckpt
? ? ? ? ? ? ? train_dir/
data文件夾下存放你制作的tfrecord訓(xùn)練測(cè)試文件和標(biāo)簽名;
pretrained文件夾下存放官網(wǎng)訓(xùn)練的權(quán)重文件;下載地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz ? ? ?
train_dir文件夾下存放你訓(xùn)練得到的模型和日志;
step4-1:在現(xiàn)有模型結(jié)構(gòu)上fine-tune
開(kāi)始訓(xùn)練,在slim文件夾下,運(yùn)行如下指令可開(kāi)始訓(xùn)練(主要是訓(xùn)練邏輯層):
python train_image_classifier.py \--train_dir=satellite/train_dir \--dataset_name=satellite \--dataset_split_name=train \--dataset_dir=satellite/data \--model_name=inception_v3 \--checkpoint_path=satellite/pretrained/inception_v3.ckpt \--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--max_number_of_steps=100000 \--batch_size=32 \--learning_rate=0.001 \--learning_rate_decay_type=fixed \--save_interval_secs=300 \--save_summaries_secs=2 \--log_every_n_steps=10 \--optimizer=rmsprop \--weight_decay=0.00004命令參數(shù)解析如下:
? --trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先來(lái)解 釋參數(shù)trainable_scopes 的作用,因?yàn)榉浅V匾?trainable_scopes 規(guī)定了在模型中fine-tune變量的范圍 。 這里的設(shè)定表示只對(duì) InceptionV3/Logits, Inception V3/ AuxLogits 兩個(gè)變量進(jìn)行微調(diào),其他變量都保持不動(dòng) 。 Inception V3/Logits,Inception V3/ AuxLogits 就相當(dāng)于在網(wǎng)絡(luò)中的 fc8 ,它們是 Inception V3的“末端層” 。 如果不設(shè)定 trainable_scopes , 就會(huì)對(duì)模型中所有的參數(shù)進(jìn)行訓(xùn)練。
? --train_dir=satellite/train_dir:表明會(huì)在 satellite/train_dir目錄下保存日志和checkpoint。
? --dataset_name=satellite、 --dataset_split_ name=train: 指定訓(xùn)練的數(shù)據(jù)集。
? --dataset_dit=satellite/data:指定訓(xùn)練數(shù)據(jù)集保存的位置。?
? --model_ name=inception _ v3 :使用的模型名稱(chēng)。?
? --checkpoint_path=satellite/pretrained/inception_v3.ckpt:預(yù)訓(xùn)練模型的保存位置。
? --checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢復(fù)預(yù)訓(xùn)練模型時(shí),不恢復(fù)這兩層。正如之前所說(shuō),這兩層是 Inception V3 模型的末端層,對(duì)應(yīng)著 ImageNet 數(shù)據(jù)集的 1000 類(lèi),和相當(dāng)前的數(shù)據(jù)集不符,因此不要去恢復(fù)它。
? --max_number_of_steps 100000:最大的執(zhí)行步數(shù)。
? --batch_size=32:每步使用的 batch 數(shù)量。
? --learning_rate=0.001 : 學(xué)習(xí)率。
? --learning_rate_decay_type=fixed:學(xué)習(xí)率是否自動(dòng)下降,此處使用固定的學(xué)習(xí)率。
? --save_interval_secs=300:每隔 300s,程序會(huì)把當(dāng)前模型保存到train_dir中。 此處就是目錄 satellite/train_dir。
? --save_summaries_secs=2:每隔 2s,就會(huì)將日志寫(xiě)入到 train_dir 中。可以用 TensorBoard 查看該日志。此處為了方便觀察,設(shè)定的時(shí)間間隔較多,實(shí)際訓(xùn)練時(shí),為了性能考慮,可以設(shè)定較長(zhǎng)的時(shí)間間隔。
? --log_every_n_steps=10:每隔10步,就會(huì)在屏上打出訓(xùn)練信息。
? --optimizer=msprop:表示選定的優(yōu)化器。
? --weight_decay=0.00004:選定的 weight_decay 值。 即模型中所高參數(shù)的 二次正則化超參數(shù)。
以上命令是只訓(xùn)練末端層 InceptionV3/Logits,Inception V3/ AuxLogits ,還 可以使用以下命令對(duì)所高層進(jìn)行訓(xùn)練:
step4-2:訓(xùn)練整個(gè)模型權(quán)重?cái)?shù)據(jù)
使用以下命令對(duì)所有層進(jìn)行訓(xùn)練:
去掉 了--trainable_scopes 參數(shù)
當(dāng)train_image_classifier.py程序啟動(dòng)后,如果訓(xùn)練文件夾(即satellite/train_dir)里沒(méi)再已經(jīng)保存的模型,就會(huì)加載 checkpoint_path 中的預(yù)訓(xùn)練模型,緊接著,程序會(huì)把初始模型保存到 train_dir中 ,命名為 model.ckpt-0, 0 表示第 0 步。 這之后,每隔 5min (參數(shù)一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序還會(huì)把當(dāng)前模型保存到同樣的文件夾中 , 命名恪式和第一次保存的格式一樣。 因?yàn)槟P捅容^大,程序只會(huì)保留最新的 5 個(gè)模型。
此外,如果中斷了程序并再次運(yùn)行,程序會(huì)首先檢查 train_dir 中有無(wú)已經(jīng)保存的模型,如果有,就不會(huì)去加載 checkpoint_path 中的預(yù)訓(xùn)練模型, 而是直接加載 train_dir 中已經(jīng)訓(xùn)練好的模型,并以此為起點(diǎn)進(jìn)行訓(xùn)練。 Slim 之所以這樣設(shè)計(jì),是為了在微調(diào)網(wǎng)絡(luò)的時(shí)候,可以方便地按階段手動(dòng)調(diào)整學(xué)習(xí)率等參數(shù)。
?
至此用slim工具包做fine-tune或重新訓(xùn)練的步驟就完成了。
相似文章參考:https://blog.csdn.net/chaipp0607/article/details/74139895
總結(jié)
以上是生活随笔為你收集整理的【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 工作介绍xml书包文件
- 下一篇: 为什么厉害的人(我)都精力那么好?我有四