搞懂RNN
文章目錄
- 1 什么是RNN
- 2 LSTM
- 3 Training
- 3.1 Learning Target
- 3.2 為什么難train
- 4 應用舉例
- 4.1 Many To One
- 4.2 Many To Many
- 4.3 其他
本文為李弘毅老師【Recurrent Neural Network(Part I)】和【Recurrent Neural Network(Part II)】的課程筆記,課程視頻來源于youtube(需翻墻)。
下文中用到的圖片均來自于李宏毅老師的PPT,若有侵權,必定刪除。
1 什么是RNN
顧名思義,RNN就是一個不斷循環的神經網絡,只不過它在循環的過程當中是有記憶的,這也是發明RNN的初衷,就是希望神經網絡在看一個序列的輸入的時候可以考慮一下前面看過的內容。
我們以Slot Filling來舉例。Slot Filiing就是填空的意思,比如今天有一個旅客說了一句"I would like to arrive Taipei on November 2nd",我們的訂票系統就需要從這句話中找到"目的地"和"期望的到達時間"這兩個slot。而幫助訂票系統來解析這個句子的,也就是我們的RNN模型。
RNN是怎么運作的呢?簡單粗暴地來說,就比如我們先往RNN里塞進一個"arrive"和一個隨機初始化的狀態向量a0a^0a0,然后RNN會輸出一個output(yiy^iyi)和一個hidden state(aia^iai)。output用來表示"arrive"這個詞在每個slot中的概率,hidden state是一個向量,包含著看過"arrive"之后,模型自己記下來的信息,然后再把這個hidden state和下一個輸入"Taipei"塞進RNN里,如此反復,直到把整個句子看完。這就是一個最最最簡單的RNN的運作流程,aia^iai就表示著RNN的記憶。不過也正是因為RNN需要記憶,所以RNN沒法并行計算。
市面上兩種RNN的記憶方法,一種是利用hidden state來當作記憶,叫做"Elman Network",另一種是利用"output"當作記憶,叫做"Jordan Network"。據說,output是有label在監督的,不像hidden state那么自由,所以"Jordan Network"的效果會好一些。不過,市面上也是有把兩者結合起來的記憶方法的,名字就不知道了。
當然,這樣的記憶只是單向的,有些時候句子的理解是需要句子后面的一些詞匯的輔助的。為了解決這個問題,也就有了雙向的RNN。雙向的RNN的兩個方向是可以并行計算的。下面這幅圖應該是比較清楚的了,每個output是結合了從頭到尾和從尾到頭兩個方向的output。
2 LSTM
市面上在使用的RNN并不是上述那么簡單的,其中LSTM是比較常用的一種方法。LSTM的每一個cell的結構如下圖所示,它吃四個input,吐出一個output。四個input分別是該time step的輸入、控制input gate的信號、控制forget gate的信號以及控制output gate的信號。最中間的那個memory cell是用來存儲之前序列留下的記憶信息的。input gate用來決定要不要接收這個輸入,forget gate用來決定要不要使用之前的記憶,output gate用來決定要不要輸出這個輸出。
這么說可能還是有點糊涂,看下面這張圖吧。比如我們某個time step的輸入為xtx^txt,首先這個xtx^txt會分別乘以一個矩陣得到LSTM需要的四個輸入。
z=Wxt+bzf=Wfxt+bfzi=Wixt+bizo=Woxt+boz = Wx^t+b\\ z^f = W^fx^t+b^f\\ z^i = W^ix^t+b^i\\ z^o = W^ox^t+b^o z=Wxt+bzf=Wfxt+bfzi=Wixt+bizo=Woxt+bo
然后,ziz^izi會經過一個激活層,來控制輸入zzz,我們把這個中間量記作inputinputinput吧。
input=σ(zi)?zinput = \sigma(z^i)*z input=σ(zi)?z
同時,zfz^fzf也會經過一個激活層,來決定是否要使用之前的記憶ct?1c^{t-1}ct?1,我們記這個中間變量叫做memorymemorymemory吧。
memory=σ(zf)?ct?1memory = \sigma(z^f)*c^{t-1} memory=σ(zf)?ct?1
這個inputinputinput和memorymemorymemory會相加在一起,作為輸出的結果,這個結果由經過一層激活層的zoz^ozo來控制。這個輸出我們叫做hidden state(hhh)。
h=σ(zo)?σ(input+memory)h = \sigma(z^o)*\sigma(input+memory) h=σ(zo)?σ(input+memory)
最后的結果yty^tyt一般還要來一層全連接。
yt=Wyhy^t = W^yh yt=Wyh
不過這只是一個time step,LSTM在多個循環的時候,長下面這樣。
可見,實際情況下,我們的輸入并不是xtx^txt,而是xtx^txt,hth^tht和ct?1c^{t-1}ct?1的結合。其中,利用ct?1c^{t-1}ct?1的這個操作被稱為peehole。
而現在主流的框架之間的實現,也略有差別,比如pytorch的實現就沒有利用peehole,激活層也有一些區別,不過整體的思路是完全一致的。
3 Training
3.1 Learning Target
在進行訓練的時候,我們需要一個目標。這個目標其實是需要根據實際的應用場景來定的。比如,我們還是用上面訂票系統的例子,我們的每一個time step的輸出是一個概率向量,分別表示著[“other”, “dest”, “time”]的概率大小。我們的label就是一個one-hot encoding的向量,比如"arrive"的label中"other"的標簽為1,其他為0;"Taipei"的label中"dest"的標簽為1,其他為0;"on"的label中"other"的標簽為1,其他為0;以此類推。然后用prediction算下cross entrophy loss就行了。
而RNN的反向傳播也是和其他的神經網絡一樣,是可以用梯度下降來做的。不過,因為它是有時間順序的,所以計算時略有不同,得要用一個叫做BPTT(Backpropagation through time)的方法來做,這里不詳述了。總之就是可以和其他網絡一樣train下去。想了解的可以看下吳恩達的RNN W1L04 : Backpropagation through time。
3.2 為什么難train
雖然RNN也是和正常的神經網絡那樣可以用gradient descent不斷地更新參數來train下去,但在RNN剛出來的時候,幾乎沒有人可以把它train出來,往往會得到一條如下圖綠色曲線這樣的結果。只有一個叫做Razvan Pascanu的人,可以train出那條藍色的曲線。其實原因時因為RNN的loss space非常陡峭,參數微小的變動,可能引起loss極大的改變。Razvan Pascanu在他寫博士論文的時候,把他一直可以train出好結果的秘訣公布了出來,那就是gradient clipping,即人為地把gradient的大小限制住了。沒錯就是這么簡單的一個技巧。
但究竟為什么RNN的梯度會發生這么大的變化呢?我們來舉個例子說明一下。假如我們有一個全世界最簡單的RNN,它輸入的weight是1,輸出的weight也是1,用來memory的weight為www,那么當我們輸入一個長度為1000,且只有第一個元素為1,其余都為0的序列時,最后一個time step的輸出就為y1000=w999y^{1000}=w^{999}y1000=w999。
這是一個什么概念呢?比如我們的w=1w=1w=1,那么y1000=1y^{1000}=1y1000=1。而此時,www只要稍稍變大一點,那么y1000y^{1000}y1000就會產生很大的變化,比如w=1.01w=1.01w=1.01時,y1000=20000y^{1000}=20000y1000=20000。這時候也就會發生所謂的梯度爆炸。而當www降到1一下時,y1000y^{1000}y1000又一直變為0了,這也就是所謂的梯度消失。
而這一切的一切都是因為RNN的參數在循環的過程中被不斷的重復使用。
LSTM在一定程度上是可以解決梯度消失的問題的。為什么LSTM可以解決梯度消失?我感覺李老師這里還是有點沒講清楚,推薦看下這篇blog(需翻墻)。一句話就是,傳統的RNN的反向傳播操作(BPTT)有個連乘的東西在,在LSTM的BPTT里是連加的,然后LSTM又有forget gate這個門在,可以使得相隔較遠的序列不相互影響。
簡單來說就是用clip可以緩解梯度爆炸,用LSTM可以緩解梯度消失。
我有一點很想知道的是,發明LSTM的大佬,是為了解決梯度消失的問題而發明了LSTM呢?還是發明了LSTM之后,發現可以緩解梯度消失?這個也真是太6了~
4 應用舉例
RNN的應用范圍非常廣泛,這里簡單列舉一下李老師視頻中提到的一些例子。
4.1 Many To One
- 情感分析。輸入一段評論,輸出該段評論是好評還是差評。
- 關鍵信息提取。輸入一篇文章,輸出該文章中的關鍵信息。
4.2 Many To Many
- 語音識別。輸入一段語音,輸出對應的文字。
- 語言翻譯。輸入一段某國的文字或語音,輸出一段另一個國家的對應意思的文字或語音。
- 聊天機器人。輸入一句話,輸出回答。
4.3 其他
- 句子文法結果分析。輸入一個句子,輸出該句子的文法結構。
- 句子自編碼。
總結
- 上一篇: MapReduce 编程实践:统计对象中
- 下一篇: LeetCode 1878. 矩阵中最大