Pytorch快速搭建并训练CNN模型?
目錄
1、數(shù)據(jù)處理模塊搭建
?2、模型構(gòu)建
3、開始訓(xùn)練
4、評估模型
?5、使用模型進行預(yù)測
6、保存模型
1、數(shù)據(jù)處理模塊搭建
這里需要根據(jù)自己的數(shù)據(jù)集進行選擇合適的方法,這里就以圖像分類作為一個例子來說明。
通常有兩種方法:
(1)采用torchvision中的datasets.ImageFolder來讀取圖像,然后采用torch.utils.data.DataLoader加載;
Ps:這種情況一般是想要讀取一自己在一個文件夾中的數(shù)據(jù)作為數(shù)據(jù)集 具體的形式如下: dataset/cat/0.jpg1.jpgdog/0.jpg1.jpg-------------------------- 這種情況使用ImageFolder就比較方便(2)繼承torch.utils.data.Dataset來實現(xiàn)用戶自定義,然后采用torch.utils.data.DataLoader加載;
torch.utils.data.Dataset 是一個表示數(shù)據(jù)集的抽象類。任何自定義的數(shù)據(jù)集都需要繼承這個類并覆寫相關(guān)方法。Pytorch提供兩種數(shù)據(jù)集: Map式數(shù)據(jù)集 Iterable式數(shù)據(jù)集對于Map式數(shù)據(jù)集處理方式:重寫getitem(self, index),len(self) 兩個內(nèi)建方法,用來表示從索引到樣本的映射(Map).當(dāng)使用dataset[idx]命令時,可以在你的硬盤中讀取你的數(shù)據(jù)集中第idx張圖片以及其標(biāo)簽(如果有的話);len(dataset)則會返回這個數(shù)據(jù)集的容量。上述參考:https://zhuanlan.zhihu.com/p/105507334自定義模塊可以參考:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Datasetdef __init__(self):# TODO# 1. Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).# 這里需要注意的是,第一步:read one data,是一個datapassdef __len__(self):# You should change 0 to the total size of your dataset.return 0參考一個實例:
from torch.utils import data import numpy as np from PIL import Image# 參考:https://zhuanlan.zhihu.com/p/105507334 class face_dataset(data.Dataset):def __init__(self):# 數(shù)據(jù)集的路徑self.file_path = './data/faces/'# 對應(yīng)的數(shù)據(jù)集和標(biāo)簽,這里是保存在txt文件中的,也有的是json文件,或者csv文件等# 根據(jù)自己的數(shù)據(jù)集情況而定f = open("final_train_tag_dict.txt","r")self.label_dict = eval(f.read())f.close()def __getitem__(self,index):"""通過index返回對應(yīng)的img和label"""label = list(self.label_dict.values())[index-1]img_id = list(self.label_dict.keys())[index-1]img_path = self.file_path+str(img_id)+".jpg"img = np.array(Image.open(img_path))return img,labeldef __len__(self):# 返回整個數(shù)據(jù)集的數(shù)量return len(self.label_dict)在這里我采用第一種形式,因為我采用的數(shù)據(jù)集是下面這種形式:
每個文件對應(yīng)一個類別,如果你采用的數(shù)據(jù)集是給定了一個image_label.txt或者image_label.csv,則采用第二種數(shù)據(jù)處理方法比較方便;
第一種方法的實現(xiàn)代碼如下:
from torch.utils.data import Dataset,DataLoader from torchvision import transforms,datasets# 1、Data augmentation # https://pytorch.org/vision/stable/transforms.html # 數(shù)據(jù)增強部分可根據(jù)自己的情況選擇,可以參考官方代碼 transforms_train = transforms.Compose([transforms.ToTensor(),transforms.ColorJitter(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))]) # valid不需要數(shù)據(jù)增強 transforms_valid = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])# 2、load dataset ds_train = datasets.ImageFolder("../data/train/",transform=transforms_train,target_transform=lambda t:torch.tensor([t]).float()) ds_valid = datasets.ImageFolder("../data/test/",transform=transforms_valid,target_transform=lambda t:torch.tensor([t]).float())官方文檔:https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolderhttps://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolder
torchvision.datasets.ImageFolder(root:?str,?transform:?Optional[Callable]?=?None,?target_transform:?Optional[Callable]?=?None,?loader:?Callable[[str],?Any]?=?<function?default_loader>,?is_valid_file:?Optional[Callable[[str],?bool]]?=?None)
Parameters:1、root (string) – Root directory path.->數(shù)據(jù)集地址2、transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop3、target_transform (callable, optional) – A function/transform that takes in the target and transforms it.主要是處理對應(yīng)的圖像標(biāo)簽4、loader (callable, optional) – A function to load an image given its path.5、is_valid_file – A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) 檢查數(shù)據(jù)集中圖像是否損壞Returns:通過:__getitem__(index: int) → 得到:Tuple[Any, Any]1、(sample, target) where target is class_index of the target class.經(jīng)過ImageLoader之后的數(shù)據(jù)具體是什么格式?
?從上圖可以看出返回的samples中是一個元組(圖像的地址,圖像的標(biāo)簽);
targets對應(yīng)每張圖像的標(biāo)簽,classes所有數(shù)據(jù)的類別,class_to_idx類別索引,extensions圖像支持的擴張名等
# 查看數(shù)據(jù)集中的類別 print(ds_train) print(ds_valid.classes) # 每個類別對應(yīng)的標(biāo)簽 print(ds_valid.class_to_idx)經(jīng)過ImageLoader處理后,還需要經(jīng)過DataLoader進一步處理:
# 通過DataLoader加載ImageFolder # 這里的num_workers為了避免出錯,盡量設(shè)置為0 dl_train = DataLoader(ds_train,batch_size=50,shuffle=True,num_workers=0) dl_valid = DataLoader(ds_valid,batch_size=50,shuffle=True,num_workers=0)注意:這個num_workers如果設(shè)置為其他數(shù)字,剛開始可能沒問題,但是后續(xù)會可能會出現(xiàn)問題,不妨設(shè)置為0;
官方文檔:
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoaderhttps://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
介紹:
DataLoader.:Combines a dataset and a sampler, and provides an iterable over the given dataset.
包括一個數(shù)據(jù)集和一個采樣器,并且提供一個給定數(shù)據(jù)集的可迭代對象;
The?DataLoader?supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
DataLoader支持的格式比較多,本次采用的是map-style;
看一下經(jīng)過DataLoader之后的數(shù)據(jù)形式:
?查看一下數(shù)據(jù)集中的部分樣本:
import matplotlib.pyplot as plt plt.figure(figsize=(5,5)) for i in range(9):# ds_train[i]也可以img,label = ds_valid[i]# 圖像是b*c*w*h->b*w*h*cimg = img.permute(1,2,0)ax = plt.subplot(3,3,i+1)ax.imshow(img.numpy())ax.set_title("label = %d"%label.item(),fontsize=8)ax.set_xticks([])ax.set_yticks([]) plt.show()?2、模型構(gòu)建
(1)使用torch.nn.Sequential按層順序構(gòu)建模型;
(2)繼承torch.nn.Module基類構(gòu)建模型;
(3)繼承torch.nn.Module基類構(gòu)建并輔助應(yīng)用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict);
nn.Sequential案例
# Using Sequential to create a small model. When `model` is run, # input will first be passed to `Conv2d(1,20,5)`. The output of # `Conv2d(1,20,5)` will be used as the input to the first # `ReLU`; the output of the first `ReLU` will become the input # for `Conv2d(20,64,5)`. Finally, the output of # `Conv2d(20,64,5)` will be used as input to the second `ReLU` model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())# Using Sequential with OrderedDict. This is functionally the # same as the above code model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))而對于nn.Module通過官方的一個example:
它是所有神經(jīng)網(wǎng)絡(luò)模塊的基類,自己定義的模型應(yīng)該繼承這個類。
同時該模塊還可以包含其他模塊,允許將它們嵌套在樹結(jié)構(gòu)中。
import torch.nn as nn import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))本文采用繼承nn.Module創(chuàng)建model
class Image_Net(nn.Module):def __init__(self):super(Image_Net,self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3)self.pool = nn.MaxPool2d(kernel_size=2,stride=2)self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5)self.dropout = nn.Dropout2d(p=0.2)self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))self.flatten = nn.Flatten()self.linear1 = nn.Linear(64,32)self.relu = nn.ReLU()self.linear2 = nn.Linear(32,1)self.sigmoid = nn.Sigmoid()def forward(self,x):x = self.conv1(x)x = self.pool(x)x = self.conv2(x)x = self.dropout(x)x = self.adaptive_pool(x)x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)y = self.sigmoid(x)return y# 實例化 net = Image_Net() print(net)3、開始訓(xùn)練
首先設(shè)置一些訓(xùn)練參數(shù)
import pandas as pd # 其他指標(biāo)可以查看sklearn.metrics from sklearn.metrics import roc_auc_score model = Image_Net() model.optimizer = torch.optim.SGD(model.parameters(),lr=0.01) model.loss_func = torch.nn.BCELoss() model.metric_func = lambda y_pred,y_true:roc_auc_score(y_true.data.numpy(),y_pred.data.numpy()) model.metric_name = "auc"下面采用函數(shù)式訓(xùn)練循環(huán)
首先創(chuàng)建train模塊
def train(model,features,labels):""":param model: :param features: :param labels: :return: loss & metric"""# 訓(xùn)練模式,dropout層發(fā)生作用model.train()# 梯度清零model.optimizer.zero_grad()# 正向傳播求損失predictions = model(features)# 計算損失loss = model.loss_func(predictions,labels)# metric計算,這里選擇的是AUCmetric = model.metric_func(predictions,labels)# 反向傳播求梯度loss.backward()model.optimizer.step()return loss.item(),metric.item()然后創(chuàng)建valid模塊:
def valid(model,features,labels):"""因為只是驗證所以不對模型的參數(shù)進行更新,只需要輸出對應(yīng)的結(jié)果就行:param model: :param features: :param labels: :return: loss & metric"""# 預(yù)測模式,dropout層不發(fā)生作用model.eval()predictions = model(features)loss = model.loss_func(predictions,labels)metric = model.metric_func(predictions,labels)return loss.item(),metric.item()設(shè)置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device) # 移動模型到cuda還有一個重要的事情,就是GPU用完一定要及時釋放:在上述設(shè)置GPU的情況下,增加下述代碼:
torch.cuda.empty_cache()tensorflow中清理顯存的方法:解決tensorflow占用GPU顯存問題?
完整的訓(xùn)練代碼如下:
import datetime def train_model(model, epochs, dl_train, dl_valid, log_step_freq):metric_name = model.metric_name# 用于記錄訓(xùn)練過程中的loss和metricdfhistory = pd.DataFrame(columns=["epoch", "loss", metric_name, "val_loss", "val_" + metric_name])print("Start Training...")nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print("==========" * 8 + "%s" % nowtime)for epoch in range(1, epochs + 1):# 1,訓(xùn)練循環(huán)-------------------------------------------------loss_sum = 0.0metric_sum = 0.0step = 1for step, (features, labels) in enumerate(dl_train, 1):# train模塊,也可以直接放在這里loss, metric = train(model, features, labels)# 打印batch級別日志loss_sum += lossmetric_sum += metric# 設(shè)置打印freqif step % log_step_freq == 0:print(("[step = %d] loss: %.3f, " + metric_name + ": %.3f") %(step, loss_sum / step, metric_sum / step))# 2,驗證循環(huán)-------------------------------------------------val_loss_sum = 0.0val_metric_sum = 0.0val_step = 1for val_step, (features, labels) in enumerate(dl_valid, 1):# valid模塊val_loss, val_metric = valid(model, features, labels)val_loss_sum += val_lossval_metric_sum += val_metric# 3,記錄日志-------------------------------------------------info = (epoch, loss_sum / step, metric_sum / step,val_loss_sum / val_step, val_metric_sum / val_step)dfhistory.loc[epoch - 1] = info# 打印epoch級別日志print(("\nEPOCH = %d, loss = %.3f," + metric_name + \" = %.3f, val_loss = %.3f, " + "val_" + metric_name + " = %.3f")% info)nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print("\n" + "==========" * 8 + "%s" % nowtime)print('Finished Training...')return dfhistory訓(xùn)練實例:
epochs = 25 dfhistory = train_model(model,epochs,dl_train,dl_valid,50) Start Training... ================================================================================2022-01-19 10:39:33 [step = 50] loss: 0.662, auc: 0.699 [step = 100] loss: 0.627, auc: 0.747 [step = 150] loss: 0.605, auc: 0.762 [step = 200] loss: 0.593, auc: 0.770EPOCH = 1, loss = 0.593,auc = 0.770, val_loss = 0.514, val_auc = 0.839================================================================================2022-01-19 10:39:43 [step = 50] loss: 0.541, auc: 0.805 [step = 100] loss: 0.539, auc: 0.806 [step = 150] loss: 0.531, auc: 0.813 [step = 200] loss: 0.524, auc: 0.819......4、評估模型
直接print(dfhistory)即可;
def plot_metric(dfhistory,metric,name):""":param dfhistory: 訓(xùn)練的info:param metric: 指定訓(xùn)練的哪個指標(biāo):return: 返回對應(yīng)的訓(xùn)練曲線"""train_metrics = dfhistory[metric]val_metrics = dfhistory['val_'+metric]epochs = range(1,len(train_metrics)+1)plt.plot(epochs,train_metrics,"bo--")plt.plot(epochs,val_metrics,"ro-")plt.title("Training and validation "+metric)plt.xlabel("Epochs")plt.ylabel(metric)plt.legend(["train_"+metric, 'val_'+metric])# saveplt.savefig("figure/"+name+".jpg")plt.show()這里將plot_metric放在utils.py中;?
from utils import plot_metric plot_metric(dfhistory,"loss",name="image_train_loss") plot_metric(dfhistory,"auc",name="image_train_auc") image_train_auc.jpg image_train_loss.jpg?5、使用模型進行預(yù)測
def predict(model,dl):model.eval()result = torch.cat([model.forward(t[0]) for t in dl])return (result.data) # 預(yù)測概率 y_pred_probs = predict(model,dl_valid) print("y_pred_probs:",y_pred_probs) # 預(yù)測類別 y_pred = torch.where(y_pred_probs>0.5,torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs)) print(y_pred)6、保存模型
淺談pytorch 模型 .pt, .pth, .pkl的區(qū)別及模型保存方式https://www.jb51.net/article/187269.htm采用torch.save保存模型參數(shù):
https://pytorch.org/docs/stable/generated/torch.save.html?highlight=save#torch.savehttps://pytorch.org/docs/stable/generated/torch.save.html?highlight=save#torch.save
torch.save(model.state_dict(),"model/model_parameter_image.pkl") net_clone = Image_Net() net_clone.load_state_dict(torch.load("model/model_parameter_image.pkl")) # test predict(net_clone,dl_valid)后續(xù)會增加onnx模型部署!
總結(jié)
以上是生活随笔為你收集整理的Pytorch快速搭建并训练CNN模型?的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: ant design vue 1.7.8
- 下一篇: 易云捷讯联合北京菜篮子配送有限公司共建“
