PyTorch 笔记(14)— nn.module 实现简单感知机和多层感知机
autograd 實現了自動微分系統,然而對深度學習來說過于底層,而本節將介紹 nn 模塊,是構建于 autograd 之上的神經網絡模塊。
1. 簡單感知機
使用 autograd 可實現深度學習模型,但其抽象程度較低,如果用其來實現深度學習模型,則需要編寫的代碼量極大。在這種情況下,torch.nn 應運而生,其是專門為深度學習設計的模塊。
torch.nn 的核心數據結構是 Module ,它是一個抽象的概念,既可以表示神經網絡中的某個層(layer),也可以表示一個包含很多層的神經網絡。
在實際使用中,最常見的做法繼承 nn.Module ,撰寫自己的網絡層。
下面先來看看如何使用 nn.Module 實現自己的全連接層。全連接層,又名仿射層,輸入 y 和輸入 x 滿足y=xW+b ,W 和 b 是可學習的參數。
import torch as t
from torch import nnclass Linear(nn.Module):def __init__(self, input_features, out_features):super(Linear, self).__init__() # 等價于 nn.Module.__init__(self)self.w = nn.Parameter(t.randn(input_features, out_features))self.b = nn.Parameter(t.randn(out_features))def forward(self, x):x = x.mm(self.w)return x + self.b.expand_as(x)layer = Linear(4, 3)x = t.randn(2, 4)
output = layer(x)
print outputfor name, parameter in layer.named_parameters():print name, parameter
output 輸出為 :
tensor([[ 1.5752, 0.6730, -0.0763],[-0.7037, -0.6641, -2.3261]], grad_fn=<ThAddBackward>)
name, parameter 輸出為:
w Parameter containing:
tensor([[-1.0459, -0.1899, 0.2202],[ 1.5751, 0.0613, 1.7350],[-0.2644, 0.7728, 1.4141],[-0.3739, -0.4349, -0.0984]], requires_grad=True)
b Parameter containing:
tensor([1.3054, 0.3063, 0.4375], requires_grad=True)
可見,全連接層的實現非常簡單,但需注意以下幾點:
- 自定義層
Linear必須繼承nn.Module,并且在其構造函數中需調用nn.Module的構造函數,即super(Linear,self).__init()__或nn.Module.__init(self)__; - 在構造函數
__init__中必須自己定義可學習的參數,并封裝成Parameter,如在本例中我們把w和b封裝成Parameter。Parameter是一種特殊的Variable,但其默認需要求導(requires_grad=True); forward函數實現前向傳播過程,其輸入可以是一個或多個variable,對x的任何操作也必須是variable支持的操作。- 無須寫反向傳播函數,因其前向傳播都是對
variable進行操作,nn.Module能夠利用autograd自動實現反向傳播,這一點比Function簡單許多。 - 使用時,直觀上可將
layer看成數學概念中的函數,調用layer(input)即可得到input對應的結果。它等價于layers.__call(input)__,在__call__函數中,主要調用的是layer.forward(x)。所以在實際使用中應盡量使用layer(x)而不是使用layer.forward(x)。 Module中的可學習參數可以通過named_parameters()或者parameters()返回迭代器,前者會給每個parameter附上名字,使其更具有辨識度。
可見,利用 Module 實現的全連接層,比利用 Function 實現的更簡單,因其不再需要寫反向傳播函數。
2. 多層感知機
Module 能夠自動檢測到自己的 parameter ,并將其作為學習參數。除了 parameter,Module 還包含子Module ,主 Module 能夠遞歸查找子 Module 中的 parameter 。下面再來看看稍微復雜一點的網絡:多層感知機。
多層感知機的網絡結構如圖所示。它由兩個全連接層組成,采用 sigmoid 函數作為激活函數(圖中沒有畫出)。
實現代碼如下:
import torch as t
from torch import nnclass Linear(nn.Module):def __init__(self, input_features, out_features):super(Linear, self).__init__() # 等價于 nn.Module.__init__(self)self.w = nn.Parameter(t.randn(input_features, out_features))self.b = nn.Parameter(t.randn(out_features))def forward(self, x):x = x.mm(self.w)return x + self.b.expand_as(x)class Perceptron(nn.Module):def __init__(self, in_features, hidden_features, out_features):nn.Module.__init__(self)self.layer1 = Linear(in_features, hidden_features) # 此處的 Linear 前面自定義的全連接層self.layer2 = Linear(hidden_features, out_features)def forward(self, x):x = self.layer1(x)x = t.sigmoid(x)return self.layer2(x)perception = Perceptron(3,4,1)
for name, param in perception.named_parameters():print(name, param.size())
輸出結果:
layer1.w torch.Size([3, 4])
layer1.b torch.Size([4])
layer2.w torch.Size([4, 1])
layer2.b torch.Size([1])
可見,即使是稍復雜的多層感知機,其實現依舊很簡單。這里需要注意以下兩個知識點。
-
構造函數
__init__中,可利用前面自定義的Linear層(Module)作為當前Module對象的一個子Module,它的可學習參數,也會成為當前Module的可學習參數。 -
在前向傳播函數中,我們有意識地將輸出變量都命名為
x,是為了能讓Python回收一些中間層的輸出,從而節省內存。但并不是所有的中間結果都會被回收,有些variable雖然名字被覆蓋,但其在反向傳播時仍需要用到,此時Python的內存回收模塊將通過檢查引用計數,不會回收這一部分內存。
Module 中 parameter 的全局命名規范如下:
Parameter直接命名。例如self.param_name = nn.Parameter(t.randn(3,4)),命名為param_name。- 子
Module中的parameter,會在其名字之前加上當前Module的名字。例如self.sub_module = SubModule(),SubModule中有個parameter的名字也叫作param_name,那么二者拼接而成的parameter name就是sub_module.param_name。
為了方便用戶使用,PyTorch 實現了神經網絡中絕大多數的 layer ,這些 layer 都繼承于 nn.Module ,封裝了可學習參數 parameter ,并實現了 forward 函數,且專門針對 GPU 運算進行了 CuDNN 優化,其速度和性能都十分優異。
- 構造函數的參數,如
nn.Linear(in_features,out_features,bias),需關注這三個參數的作用。 - 屬性、可學習參數和子
Module。如nn.Linear中有weight和bias兩個可學習參數,不包含子Module。 - 輸入輸出的形狀,如
nn.Linear的輸入形狀是(N,input_features),輸出形狀為(N,output_features),N是batch_size。
這些自定義 layer 對輸入形狀都有假設:輸入的不是單個數據,而是一個 batch 。若想輸入一個數據,必須調用 unsqueeze(0) 函數將數據偽裝成 batch_size=1 的 batch 。
總結
以上是生活随笔為你收集整理的PyTorch 笔记(14)— nn.module 实现简单感知机和多层感知机的全部內容,希望文章能夠幫你解決所遇到的問題。