【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类
生活随笔
收集整理的這篇文章主要介紹了
【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
「@Author:Runsen」
上次微調了Alexnet,這次微調ResNet實現男人和女人圖像分類。
ResNet是 Residual Networks 的縮寫,是一種經典的神經網絡,用作許多計算機視覺任務。
ResNet論文參見此處:
https://arxiv.org/abs/1512.03385
該模型是 2015 年 ImageNet 挑戰賽的獲勝者。ResNet 的根本性突破是它使我們能夠成功訓練 150 層以上的極深神經網絡。
下面是resnet18的整個網絡結構:
Resnet 18 是在 ImageNet 數據集上預訓練的圖像分類模型。
這次使用Resnet 18 實現分類性別數據集,
該性別分類數據集共有58,658 張圖像。(train:47,009 / val:11,649)
femalemaleDataset: Kaggle Gender Classification Dataset
加載數據集
設置圖像目錄路徑并初始化 PyTorch 數據加載器。和之前一樣的模板套路
import?torch import?torch.nn?as?nn import?torch.optim?as?optimimport?torchvision from?torchvision?import?datasets,?models,?transformsimport?numpy?as?np import?matplotlib.pyplot?as?pltimport?time import?osdevice?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu")?#?device?objecttransforms_train?=?transforms.Compose([transforms.Resize((224,?224)),transforms.RandomHorizontalFlip(),?#?data?augmentationtransforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])?#?normalization ])transforms_val?=?transforms.Compose([transforms.Resize((224,?224)),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225]) ])data_dir?=?'./gender_classification_dataset' train_datasets?=?datasets.ImageFolder(os.path.join(data_dir,?'Training'),?transforms_train) val_datasets?=?datasets.ImageFolder(os.path.join(data_dir,?'Validation'),?transforms_val)train_dataloader?=?torch.utils.data.DataLoader(train_datasets,?batch_size=16,?shuffle=True,?num_workers=4) val_dataloader?=?torch.utils.data.DataLoader(val_datasets,?batch_size=16,?shuffle=True,?num_workers=4)print('Train?dataset?size:',?len(train_datasets)) print('Validation?dataset?size:',?len(val_datasets))class_names?=?train_datasets.classes print('Class?names:',?class_names) plt.rcParams['figure.figsize']?=?[12,?8] plt.rcParams['figure.dpi']?=?60 plt.rcParams.update({'font.size':?20})def?imshow(input,?title):#?torch.Tensor?=>?numpyinput?=?input.numpy().transpose((1,?2,?0))#?undo?image?normalizationmean?=?np.array([0.485,?0.456,?0.406])std?=?np.array([0.229,?0.224,?0.225])input?=?std?*?input?+?meaninput?=?np.clip(input,?0,?1)#?display?imagesplt.imshow(input)plt.title(title)plt.show()#?load?a?batch?of?train?image iterator?=?iter(train_dataloader)#?visualize?a?batch?of?train?image inputs,?classes?=?next(iterator) out?=?torchvision.utils.make_grid(inputs[:4]) imshow(out,?title=[class_names[x]?for?x?in?classes[:4]])定義模型
我們使用遷移學習方法,只需要修改最后的輸出即可。
model?=?models.resnet18(pretrained=True) num_features?=?model.fc.in_features model.fc?=?nn.Linear(num_features,?2)?#?binary?classification?(num_of_class?==?2) model?=?model.to(device)criterion?=?nn.CrossEntropyLoss() optimizer?=?optim.SGD(model.parameters(),?lr=0.001,?momentum=0.9)訓練階段
由于ResNet18網絡非常復雜,深,這里只訓練num_epochs = 3
num_epochs?=?3 start_time?=?time.time()for?epoch?in?range(num_epochs):"""?Training??"""model.train()running_loss?=?0.running_corrects?=?0#?load?a?batch?data?of?imagesfor?i,?(inputs,?labels)?in?enumerate(train_dataloader):inputs?=?inputs.to(device)labels?=?labels.to(device)optimizer.zero_grad()outputs?=?model(inputs)_,?preds?=?torch.max(outputs,?1)loss?=?criterion(outputs,?labels)#?get?loss?value?and?update?the?network?weightsloss.backward()optimizer.step()running_loss?+=?loss.item()?*?inputs.size(0)running_corrects?+=?torch.sum(preds?==?labels.data)epoch_loss?=?running_loss?/?len(train_datasets)epoch_acc?=?running_corrects?/?len(train_datasets)?*?100.print('[Train?#{}]?Loss:?{:.4f}?Acc:?{:.4f}%?Time:?{:.4f}s'.format(epoch,?epoch_loss,?epoch_acc,?time.time()?-?start_time))"""?Validation"""model.eval()with?torch.no_grad():running_loss?=?0.running_corrects?=?0for?inputs,?labels?in?val_dataloader:inputs?=?inputs.to(device)labels?=?labels.to(device)outputs?=?model(inputs)_,?preds?=?torch.max(outputs,?1)loss?=?criterion(outputs,?labels)running_loss?+=?loss.item()?*?inputs.size(0)running_corrects?+=?torch.sum(preds?==?labels.data)epoch_loss?=?running_loss?/?len(val_datasets)epoch_acc?=?running_corrects?/?len(val_datasets)?*?100.print('[Validation?#{}]?Loss:?{:.4f}?Acc:?{:.4f}%?Time:?{:.4f}s'.format(epoch,?epoch_loss,?epoch_acc,?time.time()?-?start_time))「保存訓練好的模型文件」
save_path?=?'face_gender_classification_transfer_learning_with_ResNet18.pth' torch.save(model.state_dict(),?save_path)「訓練好的模型文件加載」
model?=?models.resnet18(pretrained=True) num_features?=?model.fc.in_features model.fc?=?nn.Linear(num_features,?2)? model.load_state_dict(torch.load(save_path)) model.to(device)model.eval() start_time?=?time.time()with?torch.no_grad():running_loss?=?0.running_corrects?=?0for?i,?(inputs,?labels)?in?enumerate(val_dataloader):inputs?=?inputs.to(device)labels?=?labels.to(device)outputs?=?model(inputs)_,?preds?=?torch.max(outputs,?1)loss?=?criterion(outputs,?labels)running_loss?+=?loss.item()?*?inputs.size(0)running_corrects?+=?torch.sum(preds?==?labels.data)if?i?==?0:print('[Prediction?Result?Examples]')images?=?torchvision.utils.make_grid(inputs[:4])imshow(images.cpu(),?title=[class_names[x]?for?x?in?labels[:4]])images?=?torchvision.utils.make_grid(inputs[4:8])imshow(images.cpu(),?title=[class_names[x]?for?x?in?labels[4:8]])epoch_loss?=?running_loss?/?len(val_datasets)epoch_acc?=?running_corrects?/?len(val_datasets)?*?100.print('[Validation?#{}]?Loss:?{:.4f}?Acc:?{:.4f}%?Time:?{:.4f}s'.format(epoch,?epoch_loss,?epoch_acc,?time.time()?-?start_time))在最后的測試結果中,ACC達到了97,但是模型太復雜,運行太慢了,在項目中往往不可取。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》課件合集 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 播放器之争:VLC VS SmartPl
- 下一篇: 索尼首次展示其Airpeak电影摄制无人