基于torch.nn.functional.conv2d实现CNN
在我們之前的實驗中,我們一直用torch.nn.Conv2D來實現卷積神經網絡,但是torch.nn.Conv2D在實現中是以torch.nn.functional.conv2d為基礎的,這兩者的區別是什么呢?
torch.nn.Conv2D
源碼如下:
torch.nn.Conv2dCLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')可以發現,函數參數包括輸入的通道數、輸出的通道數、卷積核大小等。在輸入中,我們不需要輸入卷積核的權重,但是如果在實驗中,我們需要用自己的卷積核,那么這種方式就不適用了。
torch.nn.functional.conv2d
源碼如下:
torch.nn.functionaltorch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor參數的具體意義:
input代表輸入圖像的大小(minibatch,in_channels,H,W),是一個四維tensor
filters代表卷積核的大小(out_channels,in_channe/groups,H,W),是一個四維tensor
bias代表每一個channel的bias,是一個維數等于out_channels的tensor
stride是一個數或者一個二元組(SH,SW),代表縱向和橫向的步長
padding是一個數或者一個二元組(PH,PW ),代表縱向和橫向的填充值
dilation是一個數,代表卷積核內部每個元素之間間隔元素的數目(不常用,默認為0)
groups是一個數,代表分組卷積時分的組數,特別的當groups = in_channel時,就是在做逐層卷積(depth-wise conv).
二者區別
torch.nn.Conv2D是一個類,而torch.nn.functional.conv2d是一個函數,在Sequential里面只能放nn.xxx,而nn.functional.xxx是不能放入Sequential里面的。
nn.Module 實現的 layer 是由 class Layer(nn.Module) 定義的特殊類,nn.functional 中的函數是純函數,由 def function(input) 定義。
nn.functional.xxx 需要自己定義 weight,每次調用時都需要手動傳入 weight,而 nn.xxx 則不用。
如果需要自己定義卷積核,那么就只能使用nn.functional.conv2d。但是在使用時,需要注意BatchNormalization和Dropout的使用方式。參考以下鏈接
接下來我們使用torch.nn.functional.conv2d來定義CNN實現Mnist數據集的識別,CNN定義如下所示:
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16,1,3,3))self.bias_1_weight = nn.Parameter(torch.randn(16))self.bn1 = nn.BatchNorm2d(16)self.conv_2_weight = nn.Parameter(torch.randn(32,16,3,3))self.bias_2_weight = nn.Parameter(torch.randn(32))self.bn2 = nn.BatchNorm2d(32)self.Linear_weight = nn.Parameter(torch.randn(10,32*32*32))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self,x):x = F.conv2d(x,self.conv_1_weight,self.bias_1_weight,stride=1,padding=1)x = F.relu(self.bn1(x),inplace=True)x = F.conv2d(x,self.conv_2_weight,self.bias_2_weight,stride=1,padding=1)x = F.relu(self.bn2(x),inplace=True)x = x.view(-1,32*32*32)x = F.linear(x,self.Linear_weight,self.bias_weight)return x實驗結果如下所示,最終的模型準確率為97%:
Epoch: 29 | Train Loss: 0.1965 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.1152 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.0702 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.0971 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.1620 | Test Accuracy: 0.97全部代碼如下所示:
import torch import torch.nn as nn import torch.nn.functional as F import torchvision from data import Getdata from torch import optimdata_train_loader,data_test_loader = Getdata() class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16,1,3,3))self.bias_1_weight = nn.Parameter(torch.randn(16))self.bn1 = nn.BatchNorm2d(16)self.conv_2_weight = nn.Parameter(torch.randn(32,16,3,3))self.bias_2_weight = nn.Parameter(torch.randn(32))self.bn2 = nn.BatchNorm2d(32)self.Linear_weight = nn.Parameter(torch.randn(10,32*32*32))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self,x):x = F.conv2d(x,self.conv_1_weight,self.bias_1_weight,stride=1,padding=1)x = F.relu(self.bn1(x),inplace=True)x = F.conv2d(x,self.conv_2_weight,self.bias_2_weight,stride=1,padding=1)x = F.relu(self.bn2(x),inplace=True)x = x.view(-1,32*32*32)x = F.linear(x,self.Linear_weight,self.bias_weight)return x model = CNN()optimizer = torch.optim.Adam(model.parameters(),lr=1e-3) loss_func = nn.CrossEntropyLoss() epoch = 30for i in range(epoch):for step,(train_x,train_y) in enumerate(data_train_loader):model.train()output = model(train_x)loss = loss_func(output,train_y)optimizer.zero_grad()loss.backward()optimizer.step()if step % 50 == 0:model.eval()with torch.no_grad():test_acc = 0num = 0for s,(test_x,test_y) in enumerate(data_test_loader):output = model(test_x)output = output.int()pred_y = torch.max((output),dim=1)[1]test_acc += test_y.eq_(pred_y).sum().item()num += test_y.size(0)print('Epoch: ',i,'| Train Loss: %.4f'% loss.item(),'| Test Accuracy: %.2f' % float(test_acc / num))努力加油a啊
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的基于torch.nn.functional.conv2d实现CNN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 蓝牙耳机有爆炸风险吗(蓝牙无线技术)
- 下一篇: 查看cuda版本(如何查看cuda的版本