LSTM简介以及数学推导(FULL BPTT)
前段時(shí)間看了一些關(guān)于LSTM方面的論文,一直準(zhǔn)備記錄一下學(xué)習(xí)過程的,因?yàn)槠渌聝?#xff0c;一直拖到了現(xiàn)在,記憶又快模糊了。現(xiàn)在趕緊補(bǔ)上,本文的組織安排是這樣的:先介紹rnn的BPTT所存在的問題,然后介紹最初的LSTM結(jié)構(gòu),在介紹加了遺忘控制門的,然后是加了peephole connections結(jié)構(gòu)的LSTM,都是按照真實(shí)提出的時(shí)間順序來寫的。本文相當(dāng)于把各個(gè)論文核心部分簡要匯集一下而做的筆記,已提供快速的了解。
一.rnn結(jié)構(gòu)的BPTT學(xué)習(xí)算法存在的問題
先看一下比較典型的BPTT一個(gè)展開的結(jié)構(gòu),如下圖,這里只考慮了部分圖,因?yàn)槠渌糠植皇沁@里要討論的內(nèi)容。
對(duì)于t時(shí)刻的誤差信號(hào)計(jì)算如下:
這樣權(quán)值的更新方式如下:
上面的公式在BPTT中是非常常見的了,那么如果這個(gè)誤差信號(hào)一直往過去傳呢,假設(shè)任意兩個(gè)節(jié)點(diǎn)u, v他們的關(guān)系是下面這樣的:
那么誤差傳遞信號(hào)的關(guān)系可以寫成如下的遞歸式:
n表示圖中一層神經(jīng)元的個(gè)數(shù),這個(gè)遞歸式的大概含義不難理解,要求t-q時(shí)刻誤差信號(hào)對(duì)t時(shí)刻誤差信號(hào)的偏導(dǎo),就先求出t-q+1時(shí)刻對(duì)t時(shí)刻的,然后把求出來的結(jié)果傳到t-q時(shí)刻,遞歸停止條件是q = 1時(shí),就是剛開始寫的那部分計(jì)算公式了。將上面的遞歸式展開后可以得到:
論文里面說的是可以通過歸納來證明,我沒仔細(xì)推敲這里了,把里面連乘展開看容易明白一點(diǎn):
整個(gè)結(jié)果式對(duì)T求和的次數(shù)是n^(q-1), 即T有n^(q-1)項(xiàng),那么下面看問題出在哪兒。
如果|T| > 1, 誤差就會(huì)隨著q的增大而呈指數(shù)增長,那么網(wǎng)絡(luò)的參數(shù)更新會(huì)引起非常大的震蕩。
如果|T| < 1, 誤差就會(huì)消失,導(dǎo)致學(xué)習(xí)無效,一般激活函數(shù)用simoid函數(shù),它的倒數(shù)最大值是0.25, 權(quán)值最大值要小于4才能保證不會(huì)小于1。
誤差呈指數(shù)增長的現(xiàn)象比較少,誤差消失在BPTT中很常見。在原論文中還有更詳細(xì)的數(shù)學(xué)分析,但是了解到此個(gè)人覺的已經(jīng)足夠理解問題所在了。
二.最初的LSTM結(jié)構(gòu)
為了克服誤差消失的問題,需要做一些限制,先假設(shè)僅僅只有一個(gè)神經(jīng)元與自己連接,簡圖如下:
根據(jù)上面的,t時(shí)刻的誤差信號(hào)計(jì)算如下:
為了使誤差不產(chǎn)生變化,可以強(qiáng)制令下式為1:
根據(jù)這個(gè)式子,可以得到:
這表示激活函數(shù)是線性的,常常的令fj(x) = x, wjj = 1.0,這樣就獲得常數(shù)誤差流了,也叫做CEC。
但是光是這樣是不行的,因?yàn)榇嬖谳斎胼敵鎏帣?quán)值更新的沖突(這里原論文里面的解釋我不是很明白),所以加上了兩道控制門,分別是input gate, output gate,來解決這個(gè)矛盾,圖如下:
圖中增加了兩個(gè)控制門,所謂控制的意思就是計(jì)算cec的輸入之前,乘以input gate的輸出,計(jì)算cec的輸出時(shí),將其結(jié)果乘以output gate的輸出,整個(gè)方框叫做block, 中間的小圓圈是CEC, 里面是一條y = x的直線表示該神經(jīng)元的激活函數(shù)是線性的,自連接的權(quán)重為1.0
三.增加forget gate
最初lstm結(jié)構(gòu)的一個(gè)缺點(diǎn)就是cec的狀態(tài)值可能會(huì)一直增大下去,增加forget gate后,可以對(duì)cec的狀態(tài)進(jìn)行控制,它的結(jié)構(gòu)如下圖:
這里的相當(dāng)于自連接權(quán)重不再是1.0,而是一個(gè)動(dòng)態(tài)的值,這個(gè)動(dòng)態(tài)值是forget gate的輸出值,它可以控制cec的狀態(tài)值,在必要時(shí)使之為0,即忘記作用,為1時(shí)和原來的結(jié)構(gòu)一樣。
四.增加Peephole的LSTM結(jié)構(gòu)
上面增加遺忘門一個(gè)缺點(diǎn)是當(dāng)前CEC的狀態(tài)不能影響到input gate, forget gate在下一時(shí)刻的輸出,所以增加了Peephole connections。結(jié)構(gòu)如下:
這里的gate的輸入部分就多加了一個(gè)來源了,forget gate, input gate的輸入來源增加了cec前一時(shí)刻的輸出,output gate的輸入來源增加了cec當(dāng)前時(shí)刻的輸出,另外計(jì)算的順序也必須保證如下:
五.一個(gè)LSTM的FULL BPTT推導(dǎo)(用誤差信號(hào))
我記得當(dāng)時(shí)看論文公式推導(dǎo)的時(shí)候很多地方比較難理解,最后隨便谷歌了幾下,找到一個(gè)寫的不錯(cuò)的類似課件的PDF,但是已經(jīng)不知道出處了,很容易就看懂LSTM的前向計(jì)算,誤差反傳更新了。把其中關(guān)于LSTM的部分放上來,首先網(wǎng)絡(luò)的完整結(jié)構(gòu)圖如下:這個(gè)結(jié)構(gòu)也是rwthlm源碼包中LSTM的結(jié)構(gòu),下面看一下公式的記號(hào):
- wij表示從神經(jīng)元i到j(luò)的連接權(quán)重(注意這和很多論文的表示是反著的)
- 神經(jīng)元的輸入用a表示,輸出用b表示
- 下標(biāo)?ι, φ 和 ω分別表示input gate, forget gate,output gate?
- c下標(biāo)表示cell,從cell到?input, forget和output gate的peephole權(quán)重分別記做 ?wcι , wcφ and wcω
- Sc表示cell c的狀態(tài)
- 控制門的激活函數(shù)用f表示,g,h分別表示cell的輸入輸出激活函數(shù)
- I表示輸入層的神經(jīng)元的個(gè)數(shù),K是輸出層的神經(jīng)元個(gè)數(shù),H是隱層cell的個(gè)數(shù)
誤差反傳更新:
原文地址:http://blog.csdn.net/a635661820/article/details/45390671
總結(jié)
以上是生活随笔為你收集整理的LSTM简介以及数学推导(FULL BPTT)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: CS231n官方笔记授权翻译总集篇发布
- 下一篇: Statistical language