Reproducing AlexNet in PyTorch for MNIST Handwritten Digit Recognition
Network overview:
AlexNet is one of the most classic network architectures in computer vision. Introduced in 2012, it won a number of competitions that year. Its architecture is shown below:
The structure is fairly simple: five convolutional layers followed by three fully connected layers. The original authors trained the network across multiple GPUs; the details can be found in the paper.
AlexNet paper: http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
The authors used the ReLU activation function and techniques such as Dropout to reduce overfitting; see the paper for more details.
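The MNIST-sized network defined later in this post omits Dropout, but as an illustration of how the paper's Dropout regularization could be added between fully connected layers, here is a minimal sketch. The layer sizes match the model below; the Dropout placement and p=0.5 are assumptions, not the author's code.

```python
import torch
import torch.nn as nn

# Hypothetical classifier head illustrating the paper's Dropout regularization.
# The layer sizes match the MNIST model defined later in this post, but the
# Dropout placement and p=0.5 are assumptions, not the author's exact code.
classifier = nn.Sequential(
    nn.Linear(256 * 3 * 3, 1024),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),  # randomly zeroes half the activations during training
    nn.Linear(1024, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(512, 10),
)

features = torch.randn(4, 256 * 3 * 3)  # dummy batch of flattened conv features
logits = classifier(features)
print(logits.shape)  # torch.Size([4, 10])
```

At inference time, calling `classifier.eval()` disables Dropout automatically, so no code change is needed between training and testing.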
Dataset
We use the MNIST handwritten digit dataset; torchvision knows the download location, so the dataset can be fetched automatically.
Defining the network architecture
Simply define the layers one by one, following the architecture diagram. The 1st, 2nd, and 5th convolutional layers are each followed by a max-pooling layer and a ReLU activation. The five convolutional layers produce a feature representation of the image, which is then fed into the fully connected layers for classification.
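Before reading the model code, it is worth verifying the flattened feature size that the first fully connected layer must accept. A minimal sketch in plain Python, using PyTorch's floor convention for conv/pool output sizes, traces the spatial resolution through the layers described above:

```python
# Trace the spatial size of a 28x28 MNIST image through the network.
# PyTorch's output-size formula: floor((size + 2*pad - kernel) / stride) + 1.

def conv_out(size, kernel, stride=1, padding=0):
    """Output spatial size of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                       # MNIST input is 1x28x28
size = conv_out(size, 3, 1, 1)  # layer1 conv  -> 28
size = conv_out(size, 2, 2)     # layer1 pool  -> 14
size = conv_out(size, 3, 1, 1)  # layer2 conv  -> 14
size = conv_out(size, 2, 2)     # layer2 pool  -> 7
size = conv_out(size, 3, 1, 1)  # layer3 conv  -> 7
size = conv_out(size, 3, 1, 1)  # layer4 conv  -> 7
size = conv_out(size, 3, 1, 1)  # layer5 conv  -> 7
size = conv_out(size, 3, 2)     # layer5 pool (k=3, s=2) -> 3
print(256 * size * size)        # 2304, i.e. 256 * 3 * 3
```

This matches the `256 * 3 * 3` input size of `fc1` in the model below.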
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/2 下午3:25

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, width_mult=1):
        super(AlexNet, self).__init__()
        # Define each convolutional layer
        self.layer1 = nn.Sequential(
            # Convolution: the input image is 1x28x28
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            # Pooling: channel count unchanged, spatial resolution halved
            nn.MaxPool2d(kernel_size=2, stride=2),
            # ReLU activation
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
        )
        # Define the fully connected layers
        self.fc1 = nn.Linear(256 * 3 * 3, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)  # ten output classes

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(-1, 256 * 3 * 3)  # flatten the feature maps
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
```

Training
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/2 下午3:38

import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from alexnet import AlexNet
from utils import plot_curve

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 30
batch_size = 256
lr = 0.01

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # Normalize the data
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data/',
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=256,
    shuffle=False,
)

# Loss function
criterion = nn.CrossEntropyLoss()

# Network
net = AlexNet().to(device)

# Optimizer
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

# Training loop
train_loss = []
for epoch in range(epochs):
    sum_loss = 0.0
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        pred = net(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        sum_loss += loss.item()
        # Print the average loss every 100 batches
        if batch_idx % 100 == 99:
            print('[%d, %d] loss: %.03f' % (epoch + 1, batch_idx + 1, sum_loss / 100))
            sum_loss = 0.0

torch.save(net.state_dict(), '/home/lwf/code/pytorch學習/alexnet圖像分類/model/model.pth')
plot_curve(train_loss)
```

The network is trained with a cross-entropy loss and an SGD optimizer; after training, the model weights are saved to disk.
The convergence of the loss function during training:
Test accuracy
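The post reports the test accuracy but does not show the evaluation loop, so the following is an assumed sketch of how accuracy could be computed with the `net` and `test_loader` from the training script. It is demonstrated here on a toy model so the sketch runs without downloading MNIST.

```python
import torch
import torch.nn as nn

def evaluate(net, loader, device="cpu"):
    """Return classification accuracy of `net` over `loader`.

    An assumed evaluation loop, not the author's code: the post reports
    test accuracy but does not show how it was computed.
    """
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = net(x).argmax(dim=1)   # predicted class per sample
            correct += (pred == y).sum().item()
            total += y.numel()
    return correct / total

# Demo on a toy model and a single random batch.
toy_net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_data = [(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))]
acc = evaluate(toy_net, toy_data)
print(f"accuracy: {acc:.3f}")
```

With the trained AlexNet, the same call would be `evaluate(net, test_loader, device)`.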
Full code: https://github.com/SPECTRELWF/pytorch-cnn-study/tree/main/Alexnet-MNIST
Personal homepage: http://liuweifeng.top:8090/