PyTorch: CIFAR10 classification with AlexNet
Full code: https://github.com/SPECTRELWF/pytorch-cnn-study
Network introduction:
AlexNet is one of the classic network architectures in computer vision. It appeared in 2012 and won several competitions that year. Its architecture is shown below:
The structure is fairly simple: five convolutional layers followed by three fully connected layers. The original authors split training across multiple GPUs; see the paper for the details.
AlexNet paper: http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
The authors used the ReLU activation function and methods such as Dropout to combat overfitting; again, see the paper for more.
Dataset introduction
This post uses the CIFAR10 dataset; for an introduction to it, see my other article:
http://liuweifeng.top:8090/archives/python%E8%AF%BB%E5%8F%96cifar10%E6%95%B0%E6%8D%AE%E9%9B%86%E5%B9%B6%E5%B0%86%E6%95%B0%E6%8D%AE%E9%9B%86%E8%BD%AC%E6%8D%A2%E4%B8%BApng%E6%A0%BC%E5%BC%8F%E5%AD%98%E5%82%A8
Defining the network
Simply define the layers one by one, following the architecture diagram. Convolutional layers 1, 2, and 5 are each followed by a max-pooling layer and a ReLU activation. After the five convolutional layers, the resulting feature representation is flattened and fed into the fully connected layers for classification.
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author: WeiFeng Liu
# @Time: 2021/11/2 3:25 PM

import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, width_mult=1):
        super(AlexNet, self).__init__()
        # Define each convolutional layer
        self.layer1 = nn.Sequential(
            # Convolution: input image is 3 x 32 x 32
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            # Pooling keeps the channel count but shrinks each feature map
            nn.MaxPool2d(kernel_size=2, stride=2),
            # ReLU activation
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
        )
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 3 * 3, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)  # ten output classes

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        # print(x.shape)
        x = x.view(-1, 256 * 3 * 3)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
```

Training
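Before training, it is worth a quick sanity check (a hypothetical snippet, not from the original repo) that a CIFAR10-sized batch really leaves the conv stack at 256 × 3 × 3, which is what `fc1` expects:

```python
import torch
from torch import nn

# Standalone copy of the five-layer conv stack defined above.
# Spatial sizes: 32 -> pool/2 -> 16 -> pool/2 -> 8 -> pool(k=3, s=2) -> (8 - 3) // 2 + 1 = 3
conv = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.MaxPool2d(2, 2), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.MaxPool2d(2, 2), nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.Conv2d(128, 256, kernel_size=3, padding=1),
    nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.MaxPool2d(3, 2), nn.ReLU(),
)
x = torch.randn(4, 3, 32, 32)  # dummy batch of four CIFAR10-sized images
out = conv(x)
print(out.shape)  # torch.Size([4, 256, 3, 3])
```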
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author: WeiFeng Liu
# @Time: 2021/11/4 12:59 PM

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from alexnet import AlexNet
from utils import plot_curve
from dataload.cifar10_dataload import CIFAR10_dataset

# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 50
batch_size = 256
lr = 0.01

transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

train_dataset = CIFAR10_dataset(r'/home/lwf/code/pytorch學習/alexnet-CIFAR10/dataset/train',
                                transform=transform)
# print(train_dataset[0])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

net = AlexNet().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

train_loss = []
for epoch in range(epochs):
    sum_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)
        optimizer.zero_grad()
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        train_loss.append(loss.item())
        print("epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, loss.item()))

torch.save(net.state_dict(), '模型地址')  # placeholder path for the saved weights
plot_curve(train_loss)
```

The network is trained with the cross-entropy loss and the SGD optimizer; after training, the model weights are saved to disk.
Test accuracy
(screenshot of the test accuracy omitted)
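The accuracy can be computed with an evaluation loop along these lines (a sketch; it assumes you rebuild a test `DataLoader` the same way as the training one, without the random flip, and reload the saved weights with `net.load_state_dict(torch.load(...))`):

```python
import torch


def evaluate(net, loader, device):
    """Return the top-1 accuracy of `net` over all batches in `loader`."""
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = net(x).argmax(dim=1)  # class with the highest logit
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total
```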
Predicting a single image:
Summary