Keras ImageDataGenerator用于数据扩充/增强的原理及方法
摘要
在這篇博客中,您將學(xué)習(xí)如何使用Keras的ImageDataGenerator類(lèi)執(zhí)行數(shù)據(jù)擴(kuò)充/增強(qiáng)。另外將介紹什么是數(shù)據(jù)增強(qiáng),數(shù)據(jù)增強(qiáng)的類(lèi)型,為什么使用數(shù)據(jù)增強(qiáng)以及它能做什么/不能做什么。
有三種數(shù)據(jù)增強(qiáng)類(lèi)型,默認(rèn)情況下,Keras的ImageDataGenerator該類(lèi)執(zhí)行就地/即時(shí)數(shù)據(jù)擴(kuò)充。
檢測(cè)到過(guò)度擬合的倆種解決方案是(1)減少模型容量或(2)執(zhí)行正則化。
數(shù)據(jù)增強(qiáng)是正則化的一種形式,使我們的網(wǎng)絡(luò)可以更好地將其推廣到我們的測(cè)試/驗(yàn)證集。
在訓(xùn)練中不應(yīng)用數(shù)據(jù)增強(qiáng)會(huì)導(dǎo)致過(guò)度擬合。應(yīng)用數(shù)據(jù)增強(qiáng),可以進(jìn)行平滑的訓(xùn)練,避免過(guò)度擬合以及擁有更高的準(zhǔn)確性/更低的損失。
強(qiáng)烈建議在所有的訓(xùn)練中都使用數(shù)據(jù)增強(qiáng)。
1. Keras ImageDataGenerator是什么
Keras的ImageDataGenerator在訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)中很常見(jiàn),是對(duì)待訓(xùn)練的數(shù)據(jù)集執(zhí)行一系列隨機(jī)變換后進(jìn)行訓(xùn)練模型,提高模型的通用性,使得模型具有更好的泛化能力。
在修改后的擴(kuò)充數(shù)據(jù)上訓(xùn)練的模型更有可能概括為訓(xùn)練集中未包含的示例數(shù)據(jù)點(diǎn)。
也可以通過(guò)一些簡(jiǎn)單的幾何變換得到增強(qiáng)后的圖像,如平移,旋轉(zhuǎn),放大/縮小,剪切,水平/垂直翻轉(zhuǎn)等;
對(duì)輸入圖像應(yīng)用少量的轉(zhuǎn)換將稍微改變其外觀,但不會(huì)更改類(lèi)標(biāo)簽,從而使數(shù)據(jù)增強(qiáng)成為適用于計(jì)算機(jī)視覺(jué)任務(wù)的非常自然,簡(jiǎn)便的方法。
2. Keras ImageDataGenerator工作原理
ImageDataGenerator接受原始數(shù)據(jù),對(duì)其進(jìn)行隨機(jī)轉(zhuǎn)換,并僅返回轉(zhuǎn)換后的新數(shù)據(jù)。
- 接受一批用于訓(xùn)練的圖像;
- 進(jìn)行此批處理并對(duì)批處理中的每個(gè)圖像應(yīng)用一系列隨機(jī)變換(包括隨機(jī)旋轉(zhuǎn),調(diào)整大小,剪切等);
- 用新的,隨機(jī)轉(zhuǎn)換的批次替換原始批次;
- 在此隨機(jī)轉(zhuǎn)換的批次上訓(xùn)練CNN(即原始數(shù)據(jù)本身不用于訓(xùn)練)。
3. Keras ImageDataGenerator的三種類(lèi)型
- (1)通過(guò)數(shù)據(jù)增強(qiáng)生成數(shù)據(jù)集和數(shù)據(jù)擴(kuò)展(較少見(jiàn))
這種方法存在一個(gè)問(wèn)題——尚未完全提高模型的泛化能力。
想象一下通過(guò)一張圖生成100張圖然后進(jìn)行訓(xùn)練;由于所有這些數(shù)據(jù)均基于超小型數(shù)據(jù)集。
我們不能期望在少量數(shù)據(jù)上訓(xùn)練NN,然后期望將其推廣到從未訓(xùn)練過(guò)且從未見(jiàn)過(guò)的數(shù)據(jù)。
- (2)就地/即時(shí)數(shù)據(jù)增強(qiáng)(最常見(jiàn))
這種加強(qiáng)方式是使用最普遍的,有倆個(gè)地方需要注意:
- ImageDataGenerator 是不是原始數(shù)據(jù)和變換后的數(shù)據(jù)都返回——只返回隨機(jī)變換的數(shù)據(jù)。
- 因?yàn)檫@種擴(kuò)充是在訓(xùn)練時(shí)完成的,因此稱(chēng)其為“就地”和“即時(shí)”數(shù)據(jù)擴(kuò)充(即不會(huì)在訓(xùn)練之前生成這些示例);
由于訓(xùn)練的時(shí)候用的是經(jīng)過(guò)隨機(jī)平移、旋轉(zhuǎn)、剪切等變換后的數(shù)據(jù)進(jìn)行的,因此模型具有了比較好的泛化能力,其在測(cè)試集上表現(xiàn)良好,而在訓(xùn)練集上將差一些,由于我們并沒(méi)有拿原始的訓(xùn)練數(shù)據(jù)訓(xùn)練,因此具有一定的偏差。
- (3)將數(shù)據(jù)集生成和就地?cái)U(kuò)充相結(jié)合
在訓(xùn)練數(shù)據(jù)很少,并且真實(shí)的場(chǎng)景數(shù)據(jù)比較難以收集的情況下,可以用將類(lèi)型2數(shù)據(jù)擴(kuò)充(即就地/即時(shí)數(shù)據(jù)擴(kuò)充)應(yīng)用于通過(guò)模擬收集的數(shù)據(jù)。
類(lèi)似于行為克隆,在自動(dòng)駕駛應(yīng)用中有運(yùn)用。
4. 項(xiàng)目結(jié)構(gòu)
5. 實(shí)現(xiàn)generate_images.py, train.py并訓(xùn)練CNN
- generate_images.py 生成數(shù)據(jù)增強(qiáng)后的數(shù)據(jù)集
- train.py 并進(jìn)行不同的數(shù)據(jù)增強(qiáng)后,進(jìn)行模型訓(xùn)練
(1)通用1張圖像生成100張訓(xùn)練數(shù)據(jù),并訓(xùn)練CNN; 50%的準(zhǔn)確率
(2)使用 Kaggle狗與貓的數(shù)據(jù) 集的一個(gè)子集,并在不進(jìn)行數(shù)據(jù)擴(kuò)充的情況下訓(xùn)練CNN; 64%的準(zhǔn)確率
(3)使用 Kaggle狗與貓的數(shù)據(jù) 集的一個(gè)子集,并在進(jìn)行數(shù)據(jù)擴(kuò)充的情況下訓(xùn)練CNN; 69%的準(zhǔn)確率
運(yùn)用(1)生成的訓(xùn)練精確度/損失圖
運(yùn)用(2)生成的訓(xùn)練精確度/損失圖
運(yùn)用(3)生成的訓(xùn)練精確度/損失圖【收斂的比較好,不會(huì)有精確度身高,損失也跟著升高的情況,可以完美的避開(kāi)過(guò)度擬合,并且具有比較好的泛化能力】
得出結(jié)論:
- 數(shù)據(jù)增強(qiáng)可以減少過(guò)度擬合,并提高模型進(jìn)行泛化的能力;
- 數(shù)據(jù)增強(qiáng)是一種正則化形式,保證驗(yàn)證和訓(xùn)練損失如何在幾乎沒(méi)有分歧的情況下下降。同樣,訓(xùn)練和驗(yàn)證拆分的分類(lèi)準(zhǔn)確性也一起提高;
- 通過(guò)使用數(shù)據(jù)增強(qiáng),可以克服過(guò)度擬合!
# 測(cè)試三種數(shù)據(jù)增強(qiáng)類(lèi)型后訓(xùn)練的模型情況# 第一種試驗(yàn):通用1張圖像生成100張訓(xùn)練數(shù)據(jù),進(jìn)行訓(xùn)練 50%的準(zhǔn)確率
# python train.py --dataset generated_dataset --plot plot_generated_dataset.png
# 探討數(shù)據(jù)擴(kuò)充如何通過(guò)兩次實(shí)驗(yàn)來(lái)減少過(guò)度擬合并提高模型進(jìn)行泛化的能力,獲取到了 64%的準(zhǔn)確率,檢測(cè)到過(guò)度擬合的倆種解決方案是(1)減少模型容量或(2)執(zhí)行正則化。
# 第二種試驗(yàn):不使用數(shù)據(jù)擴(kuò)充
# python train.py --dataset dogs_vs_cats_small --plot plot_dogs_vs_cats_no_aug.png
# 第三種試驗(yàn):運(yùn)用數(shù)據(jù)擴(kuò)充 研究數(shù)據(jù)增強(qiáng)如何充當(dāng)正則化形式 69%的準(zhǔn)確率 【注意驗(yàn)證和訓(xùn)練損失如何在幾乎沒(méi)有分歧的情況下下降。同樣,訓(xùn)練和驗(yàn)證拆分的分類(lèi)準(zhǔn)確性也一起提高。】
# 通過(guò)使用數(shù)據(jù)增強(qiáng),我們可以克服過(guò)度擬合!
# 強(qiáng)烈建議在任何情況下訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí)都使用 數(shù)據(jù)增強(qiáng);
# python train.py --dataset dogs_vs_cats_small --augment 1 --plot plot_dogs_vs_cats_with_aug.png# 導(dǎo)入必要的包
# 設(shè)置matplot為Agg以保存模型訓(xùn)練的plot圖到磁盤(pán)
import matplotlib
matplotlib.use("Agg")from pyimagesearch.resnet import ResNet
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# 導(dǎo)入ImageDataGenerator
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import cv2
import os# 構(gòu)建命令行參數(shù)
# --dataset 數(shù)據(jù)集的路徑
# --augment 是否使用數(shù)據(jù)增強(qiáng)方式2(1.通過(guò)數(shù)據(jù)增強(qiáng)生成數(shù)據(jù)集和數(shù)據(jù)擴(kuò)展(較少見(jiàn)) 2.就地/即時(shí)數(shù)據(jù)增強(qiáng)(最常見(jiàn)) 3.將數(shù)據(jù)集生成和就地?cái)U(kuò)充相結(jié)合 )
# --plot 保存 loss/accuracy 圖的路徑
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,help="path to input dataset")
ap.add_argument("-a", "--augment", type=int, default=-1,help="whether or not 'on the fly' data augmentation should be used")
ap.add_argument("-p", "--plot", type=str, default="plot.png",help="path to output loss/accuracy plot")
args = vars(ap.parse_args())# 初始化初始學(xué)習(xí)率,批處理大小batchsize,訓(xùn)練的期數(shù)epochs
INIT_LR = 1e-1
BS = 8
EPOCHS = 50# 獲取數(shù)據(jù)集,并把數(shù)據(jù),標(biāo)簽按順序存儲(chǔ)在list中
print("[INFO] loading images...")
imagePaths = list(paths.list_images(args["dataset"]))
data = []
labels = []# 循環(huán)遍歷圖片路徑
for imagePath in imagePaths:# 從文件名中提取分類(lèi)標(biāo)簽名稱(chēng),加載圖片,忽略寬高比縮放為 64*64label = imagePath.split(os.path.sep)[-2]image = cv2.imread(imagePath)image = cv2.resize(image, (64, 64))# 更新數(shù)據(jù)、標(biāo)簽listdata.append(image)labels.append(label)# 轉(zhuǎn)換數(shù)據(jù)、標(biāo)簽list為Numpy array,并將數(shù)據(jù)的像素強(qiáng)度轉(zhuǎn)換為[0,255]
data = np.array(data, dtype="float") / 255.0# 編碼類(lèi)標(biāo)簽,由字符串轉(zhuǎn)為integer轉(zhuǎn)為 一鍵熱編碼數(shù)組(echc:[1,0]代表cats,[0,1]代表dogs)
le = LabelEncoder()
labels = le.fit_transform(labels)
labels = to_categorical(labels, 2)#分組數(shù)據(jù)為75%的訓(xùn)練數(shù)據(jù),25%的測(cè)試數(shù)據(jù)
(trainX, testX, trainY, testY) = train_test_split(data, labels,test_size=0.25, random_state=42)# 初始化數(shù)據(jù)擴(kuò)充對(duì)象(初始化一個(gè)空對(duì)象)
aug = ImageDataGenerator()# 檢查是否需要進(jìn)行數(shù)據(jù)擴(kuò)充 --augment參數(shù)的值
if args["augment"] > 0:print("[INFO] performing 'on the fly' data augmentation")# 隨機(jī)旋轉(zhuǎn),縮放,移動(dòng),剪切和翻轉(zhuǎn)。(random rotations, zooms, shifts, shears, and flips)aug = ImageDataGenerator(rotation_range=20,zoom_range=0.15,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.15,horizontal_flip=True,fill_mode="nearest")# 初始化優(yōu)化器和模型
# 構(gòu)建我們的ResNet,使用隨機(jī)梯度下降優(yōu)化和學(xué)習(xí)率衰減的模型。我們使用“ binary_crossentropy” 2類(lèi)問(wèn)題的損失。如果您有兩個(gè)以上的類(lèi)標(biāo)簽,請(qǐng)確保使用“ categorial_crossentropy”
print("[INFO] compiling model...")
opt = SGD(lr=INIT_LR, momentum=0.9, decay=INIT_LR / EPOCHS)
model = ResNet.build(64, 64, 3, 2, (2, 3, 4),(32, 64, 128, 256), reg=0.0001)
model.compile(loss="binary_crossentropy", optimizer=opt,metrics=["accuracy"])# 訓(xùn)練模型
# 對(duì)象分批處理數(shù)據(jù)擴(kuò)充(僅當(dāng)--augment命令行參數(shù)已設(shè)置,對(duì)象才會(huì)執(zhí)行數(shù)據(jù)擴(kuò)充)
print("[INFO] training network for {} epochs...".format(EPOCHS))
H = model.fit(x=aug.flow(trainX, trainY, batch_size=BS),validation_data=(testX, testY),steps_per_epoch=len(trainX) // BS,epochs=EPOCHS)# 評(píng)估模型
print("[INFO] evaluating network...")
predictions = model.predict(x=testX.astype("float32"), batch_size=BS)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=le.classes_))# 繪制訓(xùn)練損失/精確度圖
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(args["plot"])
參考:
- https://www.pyimagesearch.com/2019/07/08/keras-imagedatagenerator-and-data-augmentation/
總結(jié)
以上是生活随笔為你收集整理的Keras ImageDataGenerator用于数据扩充/增强的原理及方法的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Keras TensorFlow教程:使
- 下一篇: 无精子症如何预防