深度学习笔记 —— 微调
生活随笔
收集整理的這篇文章主要介紹了
深度学习笔记 —— 微调
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
通常希望能在很大的數據集上訓練好的模型能夠幫助提升精度。
一部分做特征抽取,一部分做線性分類。
核心思想:在源數據集(通常是比較大的數據集)上訓練的模型,我們覺得可以把做特征提取那塊拿來用。(越底層的特征越為通用)
在自己的數據集上訓練的時候,使用一個與pre-train一樣架構的模型,做除了最后一層的初始化的時候,不再是隨機的初始化,而是使用pre-train訓練好的weight(可能與最終的結果很像了,總好于隨機的初始化),等價于把特征提取模塊復制過來作為我初始化的模型,使得我一開始就能做到還不錯的特征表達。(最后一層標號不一樣,所以可以隨機初始化)
?已經跟最優解比較接近了,所以使用更小的學習率和更少的迭代次數。
在微調的時候不去改變底層的類別的權重,將其固定住,不再變化那些參數,模型的復雜度也就降低了。?
在數據集很小的情況下,如果覺得全部參數參與訓練容易過擬合,可以考慮固定住底部一些層的參數,不參與更新。
?
import os import torch import torchvision from torch import nn from d2l import torch as d2l import matplotlib.pyplot as plt# save d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5') data_dir = d2l.download_extract('hotdog') train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train')) test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test')) hotdogs = [train_imgs[i][0] for i in range(8)] not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)] d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4) plt.show()# 使用RGB通道的均值和標準差,以標準化每個通道 # 因為在ImageNet上訓練的模型做了這樣的處理,所以此處做同樣的處理 normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])# 定義和初始化模型 # pretrained=True,說明不僅把模型的定義拿過來,同樣把訓練好的參數也拿過來 pretrained_net = torchvision.models.resnet18(pretrained=True) print(pretrained_net.fc) finetune_net = torchvision.models.resnet18(pretrained=True) # 最后的輸出層隨機初始化成一個線性層,此處是一個二分類問題 finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) # 只對最后一層的weight做初始化 nn.init.xavier_uniform_(finetune_net.fc.weight)# 微調模型 def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")# 把除了最后一層的所有層都拿出來,用較小的學習率;最后一層的學習率乘以10,希望其學習更快if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)train_fine_tuning(finetune_net, 5e-5)# 進行對比,不設置pretrained scratch_net = torchvision.models.resnet18() scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2) # 采用較大的學習率 train_fine_tuning(scratch_net, 5e-4, param_group=False)總結
以上是生活随笔為你收集整理的深度学习笔记 —— 微调的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 分页控件-Kaminari
- 下一篇: Ruby语言的特别之处