Pytorch机器学习/深度学习代码笔记
代碼步驟筆記
- 導(dǎo)入模塊
- 設(shè)置參數(shù)
- 數(shù)據(jù)預(yù)處理
- 定義數(shù)據(jù)集
- 1.Dataset
- 2.ImageFolder
- 加載數(shù)據(jù)集
- DataLoader
- torchvision--數(shù)據(jù)預(yù)處理要使用的庫(kù)
- torchvision.datasets
- torchvision.models
- torchvision.transforms
- 訓(xùn)練網(wǎng)絡(luò)參數(shù)
- 訓(xùn)練前的準(zhǔn)備
- 設(shè)置指定的訓(xùn)練設(shè)備(GPU、CPU)
- 定義損失函數(shù)
- 定義優(yōu)化器
- 訓(xùn)練過(guò)程
- 驗(yàn)證/測(cè)試過(guò)程
- 運(yùn)行
導(dǎo)入模塊
import torch from tensorboardX import SummaryWriter //可視化設(shè)置參數(shù)
batch_size=64 works=4 epochs=20 train_path="train" val_path="val"數(shù)據(jù)預(yù)處理
流程:先定義數(shù)據(jù)集,再將定義的數(shù)據(jù)集導(dǎo)入數(shù)據(jù)載入器(Dataloader)來(lái)讀取數(shù)據(jù)。
定義數(shù)據(jù)集有兩種方式,一種是自定義Dataset包裝類,和DataLoader類一樣,它是torch.utils.data的里的一個(gè)類,另一種是直接調(diào)用ImageFolder函數(shù),它是torchvision.datasets里的函數(shù)。
定義數(shù)據(jù)集
1.Dataset
Dataset是一個(gè)抽象類,可以自定義數(shù)據(jù)集,為了能夠方便的讀取,需要將要使用的數(shù)據(jù)包裝為Dataset類。
自定義的Dataset需要繼承它并且實(shí)現(xiàn)兩個(gè)成員方法:
1.__getitem__():該方法定義用索引(0到len(self))獲取一條數(shù)據(jù)或一個(gè)樣本。
2.__len__()方法返回?cái)?shù)據(jù)集的總長(zhǎng)度。
模板如下:
2.ImageFolder
ImageFolder假設(shè)所有的文件按文件夾保存,每個(gè)文件夾下存儲(chǔ)同一個(gè)類別的圖片,文件夾名為類名,其構(gòu)造函數(shù)如下:
import torchvision.datasets ImageFolder(root, transform=None, target_transform=None, loader=default_loader)各參數(shù)含義:
root:在root指定的路徑下尋找圖片
transform:對(duì)PIL Image進(jìn)行的轉(zhuǎn)換操作,transform的輸入是使用loader讀取圖片的返回對(duì)象
target_transform:對(duì)label的轉(zhuǎn)換
loader:給定路徑后如何讀取圖片,默認(rèn)讀取為RGB格式的PIL Image對(duì)象
label:按照文件夾名順序排序后存成字典,即{類名:類序號(hào)(從0開(kāi)始)}
舉例如下:
import torchvision.datasets #此處transform需自己定義(見(jiàn)下面torchvision.transforms),其他參數(shù)為默認(rèn)值 train_data=torchvision.datasets.ImageFolder(root=train_path,transform=transform)加載數(shù)據(jù)集
DataLoader
DataLoader是一個(gè)數(shù)據(jù)加載器類,實(shí)現(xiàn)了對(duì)數(shù)據(jù)集進(jìn)行隨機(jī)采樣和多輪次迭代的功能。在訓(xùn)練過(guò)程中,可以非常方便地實(shí)現(xiàn)多輪次小批量隨機(jī)梯度下降訓(xùn)練。
常用參數(shù)有:Dataset數(shù)據(jù)集實(shí)例,batch_size(每個(gè)batch的大小,shuffle(是否進(jìn)行攪亂操作),num_workers(加載數(shù)據(jù)的時(shí)候使用幾個(gè)子進(jìn)程),返回一個(gè)可迭代對(duì)象。
詳細(xì)有關(guān)參數(shù)見(jiàn)博客:PyTorch 中的數(shù)據(jù)類型 torch.utils.data.DataLoader
torchvision–數(shù)據(jù)預(yù)處理要使用的庫(kù)
torchvision是Pytorch中專門用來(lái)處理圖像的庫(kù)。
提供了常用圖片數(shù)據(jù)集(datasets);
訓(xùn)練好的模型(models);
一般的圖像轉(zhuǎn)換操作類(transforms),
torchvision.datasets
torchvision.datasets可以理解為PyTorch團(tuán)隊(duì)自定義的dataset,這些dataset幫我們提前處理好了很多的圖片數(shù)據(jù)集,我們拿來(lái)就可以直接使用:
- MNIST
- COCO
- Captions
- Detection
- LSUN
- ImageFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- PhotoTour
以上我們可以直接用(其他的只能通過(guò)自己自定義數(shù)據(jù)集),示例如下:
torchvision.models
torchvision提供了訓(xùn)練好的模型,可以加載后直接使用(見(jiàn)下面代碼),或者在進(jìn)行遷移學(xué)習(xí)torchvision.models模塊的子模塊中包含以下模型結(jié)構(gòu):
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
torchvision.transforms
transform模塊提供了一般的圖像轉(zhuǎn)換操作類,用作數(shù)據(jù)處理和數(shù)據(jù)增強(qiáng)。
主要提供了對(duì)PIL Image對(duì)象和Tensor對(duì)象的常用操作。
對(duì)PIL Image對(duì)象的常用操作有:
- Resize:調(diào)整圖片尺寸
- CenterCrop、RandomCrop、RandomSizedCrop:裁剪圖片
- Pad:填充
- ToTensor:將PIL Image對(duì)象轉(zhuǎn)成Tensor,會(huì)自動(dòng)將[0,255]歸一化至[0,1]
對(duì)Tensor對(duì)象的常用操作有:
- Normalize:標(biāo)準(zhǔn)化,即減均值,除以標(biāo)準(zhǔn)差
- ToPILImage:將Tensor轉(zhuǎn)為PIL Image對(duì)象。
詳細(xì)有關(guān)transforms的用法見(jiàn)博客:PyTorch 學(xué)習(xí)筆記(三):transforms的二十二個(gè)方法
訓(xùn)練網(wǎng)絡(luò)參數(shù)
訓(xùn)練前的準(zhǔn)備
設(shè)置指定的訓(xùn)練設(shè)備(GPU、CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')定義損失函數(shù)
torch.nn模塊中定義了很多標(biāo)準(zhǔn)地?fù)p失函數(shù)。
import torch.nn as nn xentropy=nn.CrossEntropyLoss() #此處定義一個(gè)交叉熵?fù)p失函數(shù)對(duì)象,該對(duì)象可以調(diào)用backward()方法實(shí)現(xiàn)誤差反向傳播。定義優(yōu)化器
torch.optim模塊提供了很多優(yōu)化算法類,
比如:torch.optim.SGD,torch.optim.Adam,torch.optim.RMSprop。這里以SGD為例。
詳細(xì)參數(shù)見(jiàn)博客:torch.optim.SGD()各參數(shù)的解釋
訓(xùn)練過(guò)程
神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程的一步迭代包含四個(gè)主要步驟:
- 前向運(yùn)算,計(jì)算給定輸入的預(yù)測(cè)結(jié)果
- 計(jì)算損失函數(shù)值
- 反向傳播(BP),計(jì)算參數(shù)梯度(計(jì)算之前要先梯度清零)
- 使用梯度下降法更新參數(shù)值
詳細(xì)代碼如下:
def train(net,optimizer,loss_fn,num_epoch,data_loader,device): '''參數(shù)分別為網(wǎng)絡(luò)模型、損失函數(shù)(對(duì)應(yīng)之前的xentropy)、epoch總次數(shù)、數(shù)據(jù)加載器、訓(xùn)練設(shè)備'''net.train() #進(jìn)入訓(xùn)練模型for epoch in range(num_epoch):print('Epoch {}/{}'.format(epoch+1, num_epochs))running_loss=0running_corrects=0for i,data in enumerate(data_loader):inputs=data[0].to(device) #輸入labels=data[1].to(device) #真實(shí)值標(biāo)簽#下面優(yōu)化過(guò)程optimizer.zero_grad() #先把前一步的梯度清除,設(shè)置梯度值為0outputs=net(inputs) #前向運(yùn)算,計(jì)算網(wǎng)絡(luò)模型在inputs上的輸出outputsloss=loss_fn(outputs,labels) #計(jì)算損失函數(shù)值loss.backward() #進(jìn)行反向傳播,計(jì)算梯度optimizer.step() #使用優(yōu)化器的step()方法,進(jìn)行梯度下降,更新模型參數(shù)#可以輸出兩種loss,loss為每次迭代的loss,running_loss為每個(gè)epoch的loss,之后再取平均值。running_loss+=loss.item() #計(jì)算每個(gè)epoch的loss總值_, preds = torch.max(outputs, 1)running_corrects += torch.sum(preds == labels).item()epoch_loss=running__loss/len(train_data) #計(jì)算每個(gè)epoch的平均lossepoch_acc = running_corrects / len(train_data)print('{} Loss: {:.4f} Acc: {:.4f}'.format('train', epoch_loss, epoch_acc))驗(yàn)證/測(cè)試過(guò)程
測(cè)試和驗(yàn)證集過(guò)程不用反向傳播,也不用更新梯度。
def evaluate(net,loss_fn,data_load,device):net.eval() #進(jìn)入模型評(píng)估模式,驗(yàn)證和測(cè)試都是這個(gè)running_loss=0correct=0.0total=0for data in data_loader:inputs=data[0].to(device) #輸入labels=data[1].to(device) #真實(shí)值標(biāo)簽with torch.no_grad(): outputs=net(inputs)loss=loss_fn(outputs,labels)running_loss+=loss.item()_,predicted=torch.max(outputs.data,1)total+=labels.size(0) #另一種計(jì)算總數(shù)的方法correct+=(predicted==labels).sum().item() #計(jì)算預(yù)測(cè)對(duì)的數(shù)epoch_loss = running_loss/len(val_data)acc=correct/total #計(jì)算準(zhǔn)確率print('{} Loss: {:.4f} Acc: {:.4f}'.format('valid', epoch_loss, acc))運(yùn)行
有兩種方式:
- 1.設(shè)立一個(gè)主函數(shù)main(),將for epoch in epochs:以及train函數(shù)和test函數(shù)放到main()里運(yùn)行就可以了。
- 2.將for epoch in epochs:和test函數(shù)放入train函數(shù),再直接運(yùn)行train()函數(shù)就可以了。
完整代碼實(shí)例:pytorch實(shí)現(xiàn)圖像分類代碼實(shí)例
總結(jié)
以上是生活随笔為你收集整理的Pytorch机器学习/深度学习代码笔记的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 国内免备案主机(免备案主机哪里好)
- 下一篇: linux的挂载点是什么意思(linux