(pytorch-深度学习系列)正向传播与反向传播-学习笔记
正向傳播與反向傳播
1. 正向傳播
正向傳播是指對神經網絡沿著從輸入層到輸出層的順序,依次計算并存儲模型的中間變量(包括輸出)。
假設輸入是一個特征為x∈Rd\boldsymbol{x} \in \mathbb{R}^dx∈Rd的樣本,且不考慮偏差項,那么中間變量
z=W(1)x,\boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x},z=W(1)x,
(矩陣相乘)
其中W(1)∈Rh×d\boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d}W(1)∈Rh×d是隱藏層的權重參數。把中間變量z∈Rh\boldsymbol{z} \in \mathbb{R}^hz∈Rh輸入按元素運算的激活函數?\phi?后,將得到向量長度為hhh的隱藏層變量
h=?(z).\boldsymbol{h} = \phi (\boldsymbol{z}).h=?(z).
隱藏層變量h\boldsymbol{h}h也是一個中間變量。假設輸出層參數只有權重W(2)∈Rq×h\boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h}W(2)∈Rq×h,可以得到向量長度為qqq的輸出層變量
o=W(2)h.\boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h}.o=W(2)h.
假設損失函數為?\ell?,且樣本標簽為yyy,可以計算出單個數據樣本的損失項
L=?(o,y).L = \ell(\boldsymbol{o}, y).L=?(o,y).
根據L2L_2L2?范數正則化的定義,給定超參數λ\lambdaλ,正則化項即(超參數λ\lambdaλ即表示懲罰的力度)
s=λ2(∣W(1)∣F2+∣W(2)∣F2),s = \frac{\lambda}{2} \left(|\boldsymbol{W}^{(1)}|_F^2 + |\boldsymbol{W}^{(2)}|_F^2\right),s=2λ?(∣W(1)∣F2?+∣W(2)∣F2?),
其中矩陣的Frobenius范數等價于將矩陣變平為向量后計算L2L_2L2?范數。最終,模型在給定的數據樣本上帶正則化的損失為
J=L+s.J = L + s.J=L+s.
我們將JJJ稱為有關給定數據樣本的目標函數。
2. 反向傳播
反向傳播用于計算神經網絡中的參數梯度。反向傳播利用微積分中的鏈式法則,沿著從輸出層到輸入層的順序進行依次計算目標函數有關神經網絡各層的中間變量以及參數的梯度。
依據鏈式法則,我們可以知道:
?J?o=prod(?J?L,?L?o)=?L?o.\frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. ?o?J?=prod(?L?J?,?o?L?)=?o?L?.
(?J?L=1,?J?s=1)\left( \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1\right)(?L?J?=1,?s?J?=1)
其中prod\text{prod}prod運算符將根據兩個輸入的形狀,在必要的操作(如轉置和互換輸入位置)后對兩個輸入做乘法。
?J?W(2)=prod(?J?o,?o?W(2))+prod(?J?s,?s?W(2))=?J?oh?+λW(2)\frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ?W(2)?J?=prod(?o?J?,?W(2)?o?)+prod(?s?J?,?W(2)?s?)=?o?J?h?+λW(2)
其中:
(?s?W(1)=λW(1),?s?W(2)=λW(2))\left(\frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}\right)(?W(1)?s?=λW(1),?W(2)?s?=λW(2))
還有:
?J?W(2)=prod(?J?o,?o?W(2))+prod(?J?s,?s?W(2))=?J?oh?+λW(2)\frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ?W(2)?J?=prod(?o?J?,?W(2)?o?)+prod(?s?J?,?W(2)?s?)=?o?J?h?+λW(2)
?J?h=prod(?J?o,?o?h)=W(2)??J?o\frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}} ?h?J?=prod(?o?J?,?h?o?)=W(2)??o?J?
?J?z=prod(?J?h,?h?z)=?J?h⊙?′(z)\frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right) ?z?J?=prod(?h?J?,?z?h?)=?h?J?⊙?′(z)
所以,可以得到:
?J?W(1)=prod(?J?z,?z?W(1))+prod(?J?s,?s?W(1))=?J?zx?+λW(1)\frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}?W(1)?J?=prod(?z?J?,?W(1)?z?)+prod(?s?J?,?W(1)?s?)=?z?J?x?+λW(1)
- 在模型參數初始化完成后,需要交替地進行正向傳播和反向傳播,并根據反向傳播計算的梯度迭代模型參數。
- 在反向傳播中使用了正向傳播中計算得到的中間變量來避免重復計算,同時這個復用也導致正向傳播結束后不能立即釋放中間變量內存。這也是訓練要比預測占用更多內存的一個重要原因。
- 這些中間變量的個數大體上與網絡層數線性相關,每個變量的大小跟批量大小和輸入個數也是線性相關的,這是導致較深的神經網絡使用較大批量訓練時更容易超內存的主要原因。
總結
以上是生活随笔為你收集整理的(pytorch-深度学习系列)正向传播与反向传播-学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 标记分布学习与标记增强
- 下一篇: 从0到1 | 0基础/转行如何用3个月搞