使用Pytorch快速搭建神经网络模型(附详细注释和讲解)
文章目錄
- 0 前言
- 1 數據讀入
- 2 模型搭建
- 3 模型訓練
- 4 模型測試
- 5 模型保存
- 6 參考博客
0 前言
代碼參考了知乎上“10分鐘快速入門PyTorch”系列,并且附上了詳細的注釋和函數講解。從今天這篇博文開始,我將和大家一起踏上Pytorch的學習道路,希望有問題可以指出!代碼可以直接復制粘貼后運行。
1 數據讀入
torchvision.datasets里面有很多數據類型,里面有官網處理好的數據,比如我們要使用的MNIST數據集(手寫數字數據集),可以通過torchvision.datasets.MNIST()來得到:
import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定義超參數 batch_size = 64 learning_rate = 1e-2 num_epochs = 5 # 訓練次數 # 判斷GPU是否可用 use_gpu = torch.cuda.is_available()# 下載訓練集 MNIST 手寫數字訓練集 # 數據是datasets類型的 train_dataset = datasets.FashionMNIST(root='../datasets', train=True, transform=transforms.ToTensor(), download=True)test_dataset = datasets.FashionMNIST(root='../datasets', train=False, transform=transforms.ToTensor()) # 將數據處理成 DataLoader train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 選擇打亂數據 test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 選擇不打亂數據可見每次會返回:數據+標簽
2 模型搭建
這里給出一個通用的模型框架:
# 基本的網絡構建類模板 class net_name(nn.Module):def __init__(self):super(net_name, self).__init__()# 可以添加各種網絡層self.conv1 = nn.Conv2d(3, 10, 3)# 具體每種層的參數可以去查看文檔def forward(self, x):# 定義向前傳播out = self.conv1(x)return out由上述框架搭建我們的神經網絡:
# 定義簡單的前饋神經網絡 class neuralNetwork(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(neuralNetwork, self).__init__() # super() 函數是用于調用父類(超類)的一個方法 # Sequential()表示將一個有序的模塊寫在一起,也就相當于將神經網絡的層按順序放在一起,這樣可以方便結構顯示self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.ReLU(True)) # 表示使用ReLU激活函數self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.ReLU(True))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim),nn.ReLU(True))# 定義向前傳播def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x # 圖片大小是28*28,中間定義了兩個隱藏層大小分別為300和100,最后輸出層為10,10分類問題 model = neuralNetwork(28 * 28, 300, 100, 10) if use_gpu:model = model.cuda() # 現在可以在GPU上跑代碼了criterion = nn.CrossEntropyLoss() # 定義損失函數類型,使用交叉熵 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定義優化器,使用隨機梯度下降3 模型訓練
第一個循環表示每個epoch,接著開始前向傳播,然后計算loss,然后反向傳播,接著優化參數,特別注意的是在每次反向傳播的時候需要將參數的梯度歸零:
# 開始模型訓練 for epoch in range(num_epochs):print('*' * 10)print(f'epoch {epoch+1}')running_loss = 0.0 # 初始值running_acc = 0.0for i, data in enumerate(train_loader, 1): # 枚舉函數enumerate返回下標和值img, label = dataimg = img.view(img.size(0), -1) # 將圖片展開為28*28# 使用GPU?if use_gpu:img = img.cuda()label = label.cuda()# 向前傳播out = model(img) # 前向傳播loss = criterion(out, label) # 計算lossrunning_loss += loss.item() # loss求和_, pred = torch.max(out, 1)running_acc += (pred == label).float().mean()# 向后傳播optimizer.zero_grad() # 梯度歸零loss.backward() # 后向傳播optimizer.step() # 更新參數if i % 300 == 0:print(f'[{epoch+1}/{num_epochs}] Loss: {running_loss/i:.6f}, Acc: {running_acc/i:.6f}')print(f'Finish {epoch+1} epoch, Loss: {running_loss/i:.6f}, Acc: {running_acc/i:.6f}')4 模型測試
特別注意的是需要用 model.eval(),讓model變成測試模式,這主要是對dropout和batch:
## 模型測試model.eval() # 讓模型變成測試模式eval_loss = 0.eval_acc = 0.for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if use_gpu:img = img.cuda()label = label.cuda()with torch.no_grad():out = model(img)loss = criterion(out, label)eval_loss += loss.item()_, pred = torch.max(out, 1)eval_acc += (pred == label).float().mean()print(f'Test Loss: {eval_loss/len(test_loader):.6f}, Acc: {eval_acc/len(test_loader):.6f}\n')1.在pytorch中的view()函數就是用來改變tensor的形狀的,例如將2行3列的tensor變為1行6列, view( )相當于numpy中resize()的功能,但是用法可能不太一樣。
2. 參數中的-1就代表這個位置由其他位置的數字來推斷,比如a tensor的數據個數是6個,如果view(1,-1),我們就可以根據tensor的元素個數推斷出-1代表6。
5 模型保存
# 保存模型 torch.save(model.state_dict(), './neural_network.pth')部分運行結果:
6 參考博客
https://www.zhihu.com/column/c_94953554
https://blog.csdn.net/gdymind/article/details/82226509
https://blog.csdn.net/jzwong/article/details/113308158
https://blog.csdn.net/qq_38929105/article/details/106438045
https://blog.csdn.net/york1996/article/details/81949843
https://blog.csdn.net/york1996/article/details/81949843
總結
以上是生活随笔為你收集整理的使用Pytorch快速搭建神经网络模型(附详细注释和讲解)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Simditor + Strust 上传
- 下一篇: 整理推荐比较好用的具有书签搜索功能的ch