pytorch dropout_PyTorch初探MNIST数据集
前言:
本文主要描述了如何使用現(xiàn)在熱度和關(guān)注度比較高的Pytorch(深度學(xué)習(xí)框架)構(gòu)建一個(gè)簡(jiǎn)單的卷積神經(jīng)網(wǎng)絡(luò),并對(duì)MNIST數(shù)據(jù)集進(jìn)行了訓(xùn)練和測(cè)試。MNIST數(shù)據(jù)集是一個(gè)28*28的手寫數(shù)字圖片集合,使用測(cè)試集來(lái)驗(yàn)證訓(xùn)練出的模型對(duì)手寫數(shù)字的識(shí)別準(zhǔn)確率。
PyTorch資料:
PyTorch的官方文檔鏈接:PyTorch documentation,在這里不僅有 API的說(shuō)明還有一些經(jīng)典的實(shí)例可供參考。
PyTorch官網(wǎng)論壇:vision,里面會(huì)有很大資料分享和一些熱門問題的解答。
PyTorch搭建神經(jīng)網(wǎng)絡(luò)實(shí)踐:
在一開始導(dǎo)入需要導(dǎo)入PyTorch的兩個(gè)核心庫(kù)文件torch和torchvision,這兩個(gè)庫(kù)基本包含了PyTorch會(huì)用到的許多方法和函數(shù)
import其中值得一提的是torchvision的datasets可以很方便的自動(dòng)下載數(shù)據(jù)集,這里使用的是MNIST數(shù)據(jù)集。另外的COCO,ImageNet,CIFCAR等數(shù)據(jù)集也可以很方的下載并使用,導(dǎo)入命令也非常簡(jiǎn)單
data_train = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="./data/",transform = transform,train = False)root指定了數(shù)據(jù)集存放的路徑,transform指定導(dǎo)入數(shù)據(jù)集時(shí)需要進(jìn)行何種變換操作,train設(shè)置為True說(shuō)明導(dǎo)入的是訓(xùn)練集合,否則為測(cè)試集合。
transform里面還有很多好的方法,可以用在圖片資源較少的數(shù)據(jù)集做Data Argumentation操作,這里只是做了個(gè)簡(jiǎn)單的Tensor格式轉(zhuǎn)換和Batch Normalize
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])數(shù)據(jù)下載完成后還需要做數(shù)據(jù)裝載操作
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)batch_size設(shè)置了每批裝載的數(shù)據(jù)圖片為64個(gè),shuffle設(shè)置為True在裝載過(guò)程中為隨機(jī)亂序
下圖為一個(gè)batch數(shù)據(jù)集(64張圖片)的顯示,可以看出來(lái)都為28*28的1維圖片
MNIST數(shù)據(jù)集圖片預(yù)覽完成數(shù)據(jù)裝載后就可以構(gòu)建核心程序了,這里構(gòu)建的是一個(gè)包含了卷積層和全連接層的神經(jīng)網(wǎng)絡(luò),其中卷積層使用torch.nn.Conv2d來(lái)構(gòu)建,激活層使用torch.nn.ReLU來(lái)構(gòu)建,池化層使用torch.nn.MaxPool2d來(lái)構(gòu)建,全連接層使用torch.nn.Linear來(lái)構(gòu)建
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(stride=2,kernel_size=2))self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x其中定義了torch.nn.Dropout(p=0.5)防止模型的過(guò)擬合
forward函數(shù)定義了前向傳播,其實(shí)就是正常卷積路徑。首先經(jīng)過(guò)self.conv1(x)卷積處理,然后進(jìn)行x.view(-1, 14*14*128)壓縮扁平化處理,最后通過(guò)self.dense(x)全連接進(jìn)行分類
之后就是對(duì)Model對(duì)象進(jìn)行調(diào)用,然后定義loss計(jì)算使用交叉熵,優(yōu)化計(jì)算使用Adam自動(dòng)化方式,最后就可以開始訓(xùn)練了
model = Model() cost = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters())在訓(xùn)練前可以查看神經(jīng)網(wǎng)絡(luò)架構(gòu)了,print輸出顯示如下
Model ((conv1): Sequential ((0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU ()(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU ()(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)))(dense): Sequential ((0): Linear (25088 -> 1024)(1): ReLU ()(2): Dropout (p = 0.5)(3): Linear (1024 -> 10)) )定義訓(xùn)練次數(shù)為5次,開始跑神經(jīng)網(wǎng)絡(luò),訓(xùn)練完成后輸入測(cè)試集合得到的結(jié)果如下
Epoch 0/5 ---------- Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600 Epoch 1/5 ---------- Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200 Epoch 2/5 ---------- Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700 Epoch 3/5 ---------- Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600 Epoch 4/5 ---------- Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300從結(jié)果上看還不錯(cuò),訓(xùn)練準(zhǔn)確率最高達(dá)到了99.73%,測(cè)試最高準(zhǔn)確率為98.96%。結(jié)果有輕微的過(guò)擬合跡象,如果使用更加健壯的卷積模型測(cè)試集會(huì)取得更加好的結(jié)果。
隨機(jī)對(duì)幾張測(cè)試集的圖片進(jìn)行預(yù)測(cè),并做可視化展示
Predict Label is: [3, 4, 9, 3] Real Label is: [3, 4, 9, 3]訓(xùn)練完成后還可以保存訓(xùn)練得到的參數(shù),方便下次導(dǎo)入后可供直接使用
torch.save(model.state_dict(), "model_parameter.pkl")完整代碼鏈接:JaimeTang/Pytorch-and-mnist(model_parameter.pkl文件較大未做上傳)
微信公眾號(hào):PyMachine
總結(jié)
以上是生活随笔為你收集整理的pytorch dropout_PyTorch初探MNIST数据集的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 第11章 Tkinter 概述
- 下一篇: vue+node实现中间层同步调用接口