普通RNN的缺陷—梯度消失和梯度爆炸
之前的RNN,無法很好地學習到時序數據的長期依賴關系。因為BPTT會發生梯度消失和梯度爆炸的問題。
RNN梯度消失和爆炸
對于RNN來說,輸入時序數據xt時,RNN 層輸出ht。這個ht稱為RNN 層的隱藏狀態,它記錄過去的信息。
語言模型的任務是根據已經出現的單詞預測下一個將要出現的單詞。
學習正確解標簽過程中,RNN層通過向過去傳遞有意義的梯度,能夠學習時間方向上的依賴關系。如果這個梯度在中途變弱(甚至沒有包含任何信息),權重參數將不會被更新,也就是所謂的RNN層無法學習長期的依賴關系。梯度的流動如下圖綠色箭頭。
隨著時間增加,RNN會產生梯度變小(梯度消失)或梯度變大(梯度爆炸)。
RNN 層在時間方向上的梯度傳播,如下圖。
反向傳播的梯度流經tanh、+、MatMul(矩陣乘積)運算。
+的反向傳播,將上游傳來的梯度原樣傳給下游,梯度值不變。
tanh的計算圖如下。它將上游傳來的梯度乘以tanh的導數傳給下游。
y=tanh(x)的值及其導數的值如下圖。導數值小于1,x越遠離0,值越小。反向傳播梯度經過tanh節點要乘上tanh的導數,這就導致梯度越來越小。
如果RNN層的激活函數使用ReLU,可以抑制梯度消失,當ReLU輸入為x時,輸出是max(0,x)。x大于0時,反向傳播將上游的梯度原樣傳遞到下游,梯度不會退化。
對于MatMul(矩陣乘積)節點。僅關注RNN層MatMul節點時的梯度反向傳播如下圖。每一次矩陣乘積計算都使用相同的權重Wh。
N = 2 # mini-batch的大小 H = 3 # 隱藏狀態向量的維數 T = 20 # 時序數據的長度dh = np.ones((N, H))#初始化為所有元素均為 1 的矩陣,dh是梯度np.random.seed(3)Wh = np.random.randn(H, H)#梯度的大小隨時間步長呈指數級增加,發生梯度爆炸 #Wh = np.random.randn(H, H) * 0.5 #梯度的大小隨時間步長呈指數級減小,發生梯度消失,權重梯度不能被更新,模型無法學習長期的依賴關系 norm_list = [] for t in range(T):dh = np.dot(dh, Wh.T)#根據反向傳播的 MatMul 節點的數量更新 dh 相應次數norm = np.sqrt(np.sum(dh**2)) / N#mini-batch(N)中的平均L2 范數,L2 范數對所有元素的平方和求平方根.norm_list.append(norm)#將各步的 dh 的大小(范數)添加到 norm_list 中print(norm_list)# 繪制圖形 plt.plot(np.arange(len(norm_list)), norm_list) plt.xticks([0, 4, 9, 14, 19], [1, 5, 10, 15, 20]) plt.xlabel('time step') plt.ylabel('norm') plt.show()如果Wh是標量,由于Wh被反復乘了T次,當Wh大于1時,梯度呈指數級增加;當 Wh 小于1時,梯度呈指數級減小。
如果wh是矩陣,矩陣的奇異值表示數據的離散程度,根據奇異值(多個奇異值中的最大值)是否大于1,可以預測梯度大小的變化。奇異值比1大是梯度爆炸的必要非充分條件。
梯度裁剪gradients clipping
梯度裁剪(gradients clipping)是解決解決梯度爆炸的一個方法。
將神經網絡用到的所有參數的梯度整合成一個,用g表示,將閾值設置為threshold,如果梯度g的L2范數大于等于該閾值,就按如下方式修正梯度。
dW1 = np.random.rand(3, 3) * 10 dW2 = np.random.rand(3, 3) * 10 grads = [dW1, dW2] max_norm = 5.0#閾值def clip_grads(grads, max_norm):total_norm = 0for grad in grads:total_norm += np.sum(grad ** 2)total_norm = np.sqrt(total_norm)#L2 范數對所有元素的平方和求平方根rate = max_norm / (total_norm + 1e-6)if rate < 1:#如果梯度的L2范數total_norm大于等于閾值max_norm,rate是小于1的,此時就需要修正梯度for grad in grads:grad *= rateprint('before:', dW1.flatten()) clip_grads(grads, max_norm) print('after:', dW1.flatten()) before: [7.14418135 3.58857143 7.82910303 8.04057218 8.8617387 1.899638863.0606848 8.14163088 5.25490409] after: [1.43122195 0.71891263 1.56843501 1.61079946 1.77530697 0.380562130.61315903 1.63104494 1.05273561]解決梯度消失
為了解決梯度消失,需要從根本上改變 RNN 層的結構。
LSTM 和GRU中增加了一種門結構,可以學習到時序數據的長期依賴關系。
總結
以上是生活随笔為你收集整理的普通RNN的缺陷—梯度消失和梯度爆炸的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 使用权值衰减算法解决神经网络过拟合问题、
- 下一篇: Java 抽象类和抽象方法