一文详解pytorch的“动态图”与“自动微分”技术
前言
眾所周知,Pytorch是一個非常流行且深受好評的深度學習訓練框架。這與它的兩大特性“動態圖”、“自動微分”有非常大的關系。“動態圖”使得pytorch的調試非常簡單,每一個步驟,每一個流程都可以被我們精確的控制、調試、輸出。甚至是在每個迭代都能夠重構整個網絡。這在其他基于靜態圖的訓練框架中是非常不方便處理的。在靜態圖的訓練框架中,必須先構建好整個網絡,然后開始訓練。如果想在訓練過程中輸出中間節點的數據或者是想要改變一點網絡的結構,就需要非常復雜的操作,甚至是不可實現的。而“自動微分”技術使得在編寫深度學習網絡的時候,只需要實現算子的前向傳播即可,無需像caffe那樣對同一個算子需要同時實現前向傳播和反向傳播。由于反向傳播一般比前向傳播要復雜,并且手動推導反向傳播的時候很容易出錯,所以“自動微分”能夠極大的節約勞動力,提升效率。
動態圖
用過caffe或者tensorflow的同學應該知道,在訓練之前需要構建一個神經網絡,caffe里面使用配置文件prototxt來進行描述,tensorflow中使用python代碼來描述。訓練之前,框架都會有一個解析和構建神經網絡的過程。構建完了之后再進行數據讀取和訓練。在訓練過程中網絡一般是不會變的,所以叫做“靜態圖”。想要獲取中間變量的輸出,可以是可以,就是比較麻煩一些,caffe使用c++訓練的話,需要獲取layer的top,然后打印,tensorflow需要通過session來獲取。但是如果想要控制網絡的運行,比如讓網絡停在某一個OP之后,這是很難做到的。即無法精確的控制網絡運行的每一步,只能等網絡運行完了,然后通過相關的接口去獲取相關的數據。而pytorch的“動態圖”機制就可以對網絡實現非常精確的控制。在pytorch運行之前,不會去創建所謂的神經網絡,這完全由python代碼定義的forward函數來描述。即我們手工編寫的forward函數就是pytorch前向運行的動態圖。當代碼執行到哪一句的時候,網絡就運行到哪一步。所以當你對forward函數進行調試,斷點,修改的時候,神經網絡也就被相應的調試、中斷和修改了。也就是說pytorch的forwad代碼就是神經網絡的執行流,或者說就是pytorch的“動態圖”。對forward的控制就是對神經網絡的控制。如下圖所示:
正因為這樣的實現機制,使得對神經網絡的調試可以像普通python代碼那樣進行調試,非常的方便和友好。并且可以在任何時候,修改網絡的結構,這就是動態圖的好處。
自動微分
上面的動態圖詳解了pytorch如何構建前向傳播的動態神經網絡的,實際上pytorch并沒有顯式的去構建一個所謂的動態圖,本質就是按照forward的代碼執行流程走了一遍而已。那么對于反向傳播,因為我們沒有構建反向傳播的代碼,pytorch也就無法像前向傳播那樣,通過我們手動編寫的代碼執行流進行反向傳播。那么pytorch是如何實現精確的反向傳播的呢?其實最大的奧秘就藏在tensor的grad_fn屬性里面。有的同學可能在調試pytorch代碼的時候已經不經意的遇到過這個grad_fn屬性。如下圖所示:
Pytorch中的tensor對象都有一個叫做grad_fn的屬性,它實際上是一個鏈表,實現在pytorch源碼的autograd下面。該屬性記錄了該tensor是如何由前一個tensor產生的。在深入探究grad_fn之前,先來了解一下pytroch中的leaf tensor和非leaf tensor。
?
Leaf/非leaf tensor:
Pytorch中的tensor有兩種產生方式,一種是憑空創建的,例如一些op里面的params,訓練的images,這些tensor,他們不是由其他tensor計算得來的,而是通過torch.zeros(),torch.ones(),torch.from_numpy()等憑空創建出來的。另外一種產生方式是由某一個tensor經過一個op計算得到,例如tensor a通過conv計算得到tensor b。其實這兩種op創建方式對應的就是leaf節點(葉子節點)和非leaf(非葉子節點)。如下圖所示,為一個cnn網絡中的leaf節點和非leaf節點。黃色的節點對應的tensor就是憑空生成的,是leaf節點;藍色的tensor就是通過其他tensor計算得來的,是非leaf節點。那么顯而易見,藍色的非leaf節點的grad_fn是有值的,因為它的梯度需要繼續向后傳播給創建它的那個節點。而黃色的leaf節點的grad_fn為None,因為他們不是由其他節點創建而來,他們的梯度不需要繼續反向傳播。
深究grad_fn:
grad_fn是python層的封裝,其實現對應的就是pytorch源碼在autograd下面的node對象,為C++實現,如下圖所示:
node其實是一個鏈表,有一個next_edges_屬性,里面存儲著指向下一級的所有node。注意它不是一個簡單的單向鏈表,因為很多tensor可能是由多個tensor創建來的。例如tensor a = tensor b + tensor c. 那么tensor a的grad_fn屬性里面的next_edges就會有兩個指針,分別指向tensor b和tensor c的grad_fn屬性。在python層,next_edges_屬性被封裝成了next_functions。因此正確的說法是:tensor a的grad_fn屬性里面的next_ functions,指向了tensor b和tensor c的grad_fn屬性。其實有了這個完整的鏈表,就已經完整的表達了反向傳播的計算圖。就可以完成完整的反向傳播了。 下面我們通過一個小例子來進一步說明grad_fn是如何表達反向傳播計算圖的。首先我們定義一個非常簡單的網絡:僅有兩個conv層,一個relu層,一個pool層,如下圖所示(conv層帶有參數weights和bias):
對應的代碼片段如下所示:
class TinyCnn(torch.nn.Module):def __init__(self, arg_dict={}):super(TinyCnn, self).__init__()self.conv = torch.nn.Conv2d(3, 3, kernel_size=2, stride=2)self.relu = torch.nn.ReLU(inplace=True)self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)def forward(self, images):conv_out = self.conv(images)relu_out = self.relu(conv_out)pool_out = self.pool(relu_out)return pool_outcnn = TinyCnn() loss_fun = torch.nn.BCELoss() images = torch.rand((1,3,4,4)) labels = torch.rand((1,3,1,1)) preds = cnn(images) loss = loss_fun(preds, labels) loss.backward()那么當代碼執行到loss = loss_fun(preds, labels),我們看看loss的grad_fn以及其對應的next_functions:
可以看到loss的grad_fn為:<BinaryCrossEntropyBackward object at 0x000001A07E079FC8>,而它的next_functions為:(<MaxPool2DWithIndicesBackward object at 0x000001A07E08BC88>, 0),繼續跟蹤MaxPool2DWithIndicesBackward的nex_functions為:(<ReluBackward1 object at 0x000001A07E079B88>, 0),如果繼續跟蹤下去,整個反向傳播的計算圖就非常的直觀了,使用下圖表示:
Images由于是葉子節點,且不需要求梯度,因此ThnnConv2DBackward的第一個next_functions對應的是None。而conv中的weights和bias雖然也是葉子節點,但是需要求梯度,因此增加了一個AccumulateGrad的方法,表示可累計梯度,實際上就是對weights和bias的梯度的保存。
grad_fn是如何生成的?
有了上面的介紹,其實大家應該已經大致了解了pytorch自動微分的大致流程。實際上是通過tensor的gran_fn來組織的,grad_fn本質上是一個鏈表,指向下一級別的tensor的grad_fn,因此通過這樣一個鏈表構成了一個完整的反向計算的動態圖。那么最后有一個問題就是tensor的grad_fn是如何構建的?無論是我們自己編寫的上層代碼,還是在pytorch底層的op實現里面,并沒有顯示的去創建grad_fn,那么它是在何時,又是如何組裝的?實際上通過編譯pytorch源碼就能發現端倪。Pytorch會對所有底層算子進一個二次封裝,在做完正常的op前向之后,增加了grad_fn的設置,next_functions的設置等流程。如下圖所示為原始卷積的前向流程和經過pytroch自動封裝的卷積前向計算流程對比。可以明顯的看到多了一些對grad_fn設置的代碼。
后記
以上流程就是pytorch的“動態圖”與“自動微分”的核心邏輯。基于pytorch1.6.0源碼分析,由于作者才疏學淺,且涉獵范圍有限,難免有所錯誤,如果有不對的地方,還請見諒并指正。
?
總結
以上是生活随笔為你收集整理的一文详解pytorch的“动态图”与“自动微分”技术的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 像素密度推高到 5000 PPI!麻省理
- 下一篇: 宝马 i4 车主收到通知:车停在陡坡上无