控制论python_[干货]深入浅出LSTM及其Python代码实现
人工神經(jīng)網(wǎng)絡(luò)在近年來大放異彩,在圖像識別、語音識別、自然語言處理與大數(shù)據(jù)分析領(lǐng)域取得了巨大的成功,而長短期記憶網(wǎng)絡(luò)LSTM作為一種特殊的神經(jīng)網(wǎng)絡(luò)模型,它又有哪些特點呢?作為初學(xué)者,如何由淺入深地理解LSTM并將其應(yīng)用到實際工作中呢?本文將由淺入深介紹循環(huán)神經(jīng)網(wǎng)絡(luò)RNN和長短期記憶網(wǎng)絡(luò)LSTM的基本原理,并基于Pytorch實現(xiàn)一個簡單應(yīng)用例子,提供完整代碼。
1. 神經(jīng)網(wǎng)絡(luò)簡介
1.1 神經(jīng)網(wǎng)絡(luò)起源
人工神經(jīng)網(wǎng)絡(luò)(Aritificial Neural Networks, ANN)是一種仿生的網(wǎng)絡(luò)結(jié)構(gòu),起源于對人類大腦的研究。人工神經(jīng)網(wǎng)絡(luò)(Aritificial Neural Networks)也常被簡稱為神經(jīng)網(wǎng)絡(luò)(Neural Networks, NN),基本思想是通過大量簡單的神經(jīng)元之間的相互連接來構(gòu)造復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu),信號(數(shù)據(jù))可以在這些神經(jīng)元之間傳遞,通過激活不同的神經(jīng)元和對傳遞的信號進(jìn)行加權(quán)來使得信號被放大或衰減,經(jīng)過多次的傳遞來改變信號的強(qiáng)度和表現(xiàn)形式。
神經(jīng)網(wǎng)絡(luò)最早起源于20世紀(jì)40年代,神經(jīng)科學(xué)家和控制論專家Warren McCulloch和邏輯學(xué)家Walter Pitts基于數(shù)學(xué)和閾值邏輯算法創(chuàng)造了最早的神經(jīng)網(wǎng)絡(luò)計算模型。由于當(dāng)時的計算資源有限,無法構(gòu)建層數(shù)太多的神經(jīng)網(wǎng)絡(luò)(3層以內(nèi)),因此神經(jīng)網(wǎng)絡(luò)的應(yīng)用范圍很局限。隨著計算機(jī)技術(shù)的發(fā)展,神經(jīng)網(wǎng)絡(luò)層數(shù)的增加帶來的計算負(fù)擔(dān)已經(jīng)可以被現(xiàn)代計算機(jī)解決,各位前輩大牛對于神經(jīng)網(wǎng)絡(luò)的理解也進(jìn)一步加深。歷史上神經(jīng)網(wǎng)絡(luò)的發(fā)展大致經(jīng)歷了三次高潮:20世紀(jì)40年代的控制論、20世紀(jì)80年代到90年代中期的聯(lián)結(jié)主義和2006年以來的深度學(xué)習(xí)。深度學(xué)習(xí)的出現(xiàn)直接引爆了一部分應(yīng)用市場,這里有太多的案例可以講,想了解的讀者可參考下面的鏈接:
1.2 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的缺陷
本文假設(shè)讀者已經(jīng)了解神經(jīng)網(wǎng)絡(luò)的基本原理了,如果有讀者是初次接觸神經(jīng)網(wǎng)絡(luò)的知識,這里分享一篇個人覺得非常適合初學(xué)者的文章:
1.2.1 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的原理回顧
傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)可以用下面這張圖表示:
NN.jpg
其中:
輸入層:可以包含多個神經(jīng)元,可以接收多維的信號輸入(特征信息);
輸出層:可以包含多個神經(jīng)元,可以輸出多維信號;
隱含層:可以包含多個神經(jīng)網(wǎng)絡(luò)層,每一層包含多個神經(jīng)元。
每層的神經(jīng)元與上一層神經(jīng)元和下一層神經(jīng)元連接(類似生物神經(jīng)元的突觸),這些連接通路用于信號傳遞。每個神經(jīng)元接收來自上一層的信號輸入,使用一定的加和規(guī)則將所有的信號輸入?yún)R聚到一起,并使用激活函數(shù)將輸入信號激活為輸出信號,再將信號傳遞到下一層。
神經(jīng)網(wǎng)絡(luò)為什么要使用激活函數(shù)?不同的激活函數(shù)有什么不同的作用?讀者可參考:
所以,影響神經(jīng)網(wǎng)絡(luò)表現(xiàn)能力的主要因素有神經(jīng)網(wǎng)絡(luò)的層數(shù)、神經(jīng)元的個數(shù)、神經(jīng)元之間的連接方式以及神經(jīng)元所采用的激活函數(shù)。神經(jīng)元之間以不同的連接方式(全連接、部分連接)組合,可以構(gòu)成不同神經(jīng)網(wǎng)絡(luò),對于不同的信號處理效果也不一樣。但是,目前依舊沒有一種通用的方法可以根據(jù)信號輸入的特征來決定神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu),這也是神經(jīng)網(wǎng)絡(luò)模型被稱為黑箱的原因之一,帶來的問題也就是模型的參數(shù)不容易調(diào)整,也不清楚其中到底發(fā)生了什么。因此,在不斷的探索當(dāng)中,前輩大牛們總結(jié)得到了許多經(jīng)典的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu):MLP、BP、FFNN、CNN、RNN等。詳細(xì)的介紹見以下鏈接:
神經(jīng)網(wǎng)絡(luò)優(yōu)點很明顯,給我們提供了構(gòu)建模型的便利,你大可不用顧及模型本身是如何作用的,只需要按照規(guī)則構(gòu)建網(wǎng)絡(luò),然后使用訓(xùn)練數(shù)據(jù)集不斷調(diào)整參數(shù),在許多問題上都能得到一個比較“能接受”的結(jié)果,然而我們對其中發(fā)生了什么是未可知的。在深度學(xué)習(xí)領(lǐng)域,許多問題都可以通過構(gòu)建深層的神經(jīng)網(wǎng)絡(luò)模型來解決。這里,我們不對神經(jīng)網(wǎng)絡(luò)的優(yōu)點做過多闡述。
1.2.2 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的缺陷
從傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)我們可以看出,信號流從輸入層到輸出層依次流過,同一層級的神經(jīng)元之間,信號是不會相互傳遞的。這樣就會導(dǎo)致一個問題,輸出信號只與輸入信號有關(guān),而與輸入信號的先后順序無關(guān)。并且神經(jīng)元本身也不具有存儲信息的能力,整個網(wǎng)絡(luò)也就沒有“記憶”能力,當(dāng)輸入信號是一個跟時間相關(guān)的信號時,如果我們想要通過這段信號的“上下文”信息來理解一段時間序列的意思,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)就顯得無力了。與我們?nèi)祟惖睦斫膺^程類似,我們聽到一句話時往往需要通過這句話中詞語出現(xiàn)的順序以及我們之前所學(xué)的關(guān)于這些詞語的意思來理解整段話的意思,而不是簡單的通過其中的幾個詞語來理解。
例如,在自然語言處理領(lǐng)域,我們要讓神經(jīng)網(wǎng)絡(luò)理解這樣一句話:“地球上最高的山是珠穆朗瑪峰”,按照傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),它可能會將這句話拆分為幾個單獨的詞(地球、上、最高的、山、是、珠穆朗瑪峰),分別輸入到模型之中,而不管這幾個詞之間的順序。然而,直觀上我們可以看到,這幾個詞出現(xiàn)的順序是與最終這句話要表達(dá)的意思是密切相關(guān)的,但傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)無法處理這種情況。
因此,我們需要構(gòu)建具有“記憶”能力的神經(jīng)網(wǎng)絡(luò)模型,用來處理需要理解上下文意思的信號,也就是時間序列數(shù)據(jù)。循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)就是用來處理這類信號的,RNN之所以能夠有效的處理時間序列數(shù)據(jù),主要是基于它比較特殊的運行原理。下面將介紹RNN的構(gòu)建過程和基本運行原理,然后引入長短期記憶網(wǎng)絡(luò)(LSTM)。
2. 循環(huán)神經(jīng)網(wǎng)絡(luò)RNN
2.1 RNN的構(gòu)造過程
RNN是一種特殊的神經(jīng)網(wǎng)路結(jié)構(gòu),其本身是包含循環(huán)的網(wǎng)絡(luò),允許信息在神經(jīng)元之間傳遞,如下圖所示:
RNN-rolled.png
圖示是一個RNN結(jié)構(gòu)示意圖,圖中的
表示神經(jīng)網(wǎng)絡(luò)模型,
表示模型的輸入信號,
表示模型的輸出信號,如果沒有
的輸出信號傳遞到
的那個箭頭, 這個網(wǎng)絡(luò)模型與普通的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)無異。那么這個箭頭做了什么事情呢?它允許
將信息傳遞給
,神經(jīng)網(wǎng)絡(luò)將自己的輸出作為輸入了!這怎么理解啊?作者第一次看到這個圖的時候也是有點懵,讀者可以思考一分鐘。
關(guān)鍵在于輸入信號是一個時間序列,跟時間
有關(guān)。也就是說,在
時刻,輸入信號
作為神經(jīng)網(wǎng)絡(luò)
的輸入,
的輸出分流為兩部分,一部分輸出給
,一部分作為一個隱藏的信號流被輸入到
中,在下一次時刻輸入信號
時,這部分隱藏的信號流也作為輸入信號輸入到了
中。此時神經(jīng)網(wǎng)絡(luò)
就同時接收了
時刻和
時刻的信號輸入了,此時的輸出信號又將被傳遞到下一時刻的
中。如果我們把上面那個圖根據(jù)時間
展開來看,就是:
RNN-unrolled.png
看到了嗎?
時刻的信息輸出給
時刻的模型
了,
時刻的信息輸出給
時刻的模型
了,
。這樣,相當(dāng)于RNN在時間序列上把自己復(fù)制了很多遍,每個模型都對應(yīng)一個時刻的輸入,并且當(dāng)前時刻的輸出還作為下一時刻的模型的輸入信號。
這樣鏈?zhǔn)降慕Y(jié)構(gòu)揭示了RNN本質(zhì)上是與序列相關(guān)的,是對于時間序列數(shù)據(jù)最自然的神經(jīng)網(wǎng)絡(luò)架構(gòu)。并且理論上,RNN可以保留以前任意時刻的信息。RNN在語音識別、自然語言處理、圖片描述、視頻圖像處理等領(lǐng)域已經(jīng)取得了一定的成果,而且還將更加大放異彩。在實際使用的時候,用得最多的一種RNN結(jié)構(gòu)是LSTM,為什么是LSTM呢?我們從普通RNN的局限性說起。
2.2 RNN的局限性
RNN利用了神經(jīng)網(wǎng)絡(luò)的“內(nèi)部循環(huán)”來保留時間序列的上下文信息,可以使用過去的信號數(shù)據(jù)來推測對當(dāng)前信號的理解,這是非常重要的進(jìn)步,并且理論上RNN可以保留過去任意時刻的信息。但實際使用RNN時往往遇到問題,請看下面這個例子。
假如我們構(gòu)造了一個語言模型,可以通過當(dāng)前這一句話的意思來預(yù)測下一個詞語。現(xiàn)在有這樣一句話:“我是一個中國人,出生在普通家庭,我最常說漢語,也喜歡寫漢字。我喜歡媽媽做的菜”。我們的語言模型在預(yù)測“我最常說漢語”的“漢語”這個詞時,它要預(yù)測“我最長說”這后面可能跟的是一個語言,可能是英語,也可能是漢語,那么它需要用到第一句話的“我是中國人”這段話的意思來推測我最常說漢語,而不是英語、法語等。而在預(yù)測“我喜歡媽媽做的菜”的最后的詞“菜”時并不需要“我是中國人”這個信息以及其他的信息,它跟我是不是一個中國人沒有必然的關(guān)系。
這個例子告訴我們,想要精確地處理時間序列,有時候我們只需要用到最近的時刻的信息。例如預(yù)測“我喜歡媽媽做的菜”最后這個詞“菜”,此時信息傳遞是這樣的:
RNN-shorttermdepdencies.png
“菜”這個詞與“我”、“喜歡”、“媽媽”、“做”、“的”這幾個詞關(guān)聯(lián)性比較大,距離也比較近,所以可以直接利用這幾個詞進(jìn)行最后那個詞語的推測。
而有時候我們又需要用到很早以前時刻的信息,例如預(yù)測“我最常說漢語”最后的這個詞“漢語”。此時信息傳遞是這樣的:
RNN-longtermdependencies.png
此時,我們要預(yù)測“漢語”這個詞,僅僅依靠“我”、“最”、“常”、“說”這幾個詞還不能得出我說的是漢語,必須要追溯到更早的句子“我是一個中國人”,由“中國人”這個詞語來推測我最常說的是漢語。因此,這種情況下,我們想要推測“漢語”這個詞的時候就比前面那個預(yù)測“菜”這個詞所用到的信息就處于更早的時刻。
而RNN雖然在理論上可以保留所有歷史時刻的信息,但在實際使用時,信息的傳遞往往會因為時間間隔太長而逐漸衰減,傳遞一段時刻以后其信息的作用效果就大大降低了。因此,普通RNN對于信息的長期依賴問題沒有很好的處理辦法。
為了克服這個問題,Hochreiter等人在1997年改進(jìn)了RNN,提出了一種特殊的RNN模型——LSTM網(wǎng)絡(luò),可以學(xué)習(xí)長期依賴信息,在后面的20多年被改良和得到了廣泛的應(yīng)用,并且取得了極大的成功。
3. 長短時間記憶網(wǎng)絡(luò)(LSTM)
3.1 LSTM與RNN的關(guān)系
長短期記憶(Long Short Term Memory,LSTM)網(wǎng)絡(luò)是一種特殊的RNN模型,其特殊的結(jié)構(gòu)設(shè)計使得它可以避免長期依賴問題,記住很早時刻的信息是LSTM的默認(rèn)行為,而不需要專門為此付出很大代價。
普通的RNN模型中,其重復(fù)神經(jīng)網(wǎng)絡(luò)模塊的鏈?zhǔn)侥P腿缦聢D所示,這個重復(fù)的模塊只有一個非常簡單的結(jié)構(gòu),一個單一的神經(jīng)網(wǎng)絡(luò)層(例如tanh層),這樣就會導(dǎo)致信息的處理能力比較低。
LSTM3-SimpleRNN.png
而LSTM在此基礎(chǔ)上將這個結(jié)構(gòu)改進(jìn)了,不再是單一的神經(jīng)網(wǎng)絡(luò)層,而是4個,并且以一種特殊的方式進(jìn)行交互。
LSTM3-chain.png
粗看起來,這個結(jié)構(gòu)有點復(fù)雜,不過不用擔(dān)心,接下來我們會慢慢解釋。在解釋這個神經(jīng)網(wǎng)絡(luò)層時我們先來認(rèn)識一些基本的模塊表示方法。圖中的模塊分為以下幾種:
LSTM2-notation.png
黃色方塊:表示一個神經(jīng)網(wǎng)絡(luò)層(Neural Network Layer);
粉色圓圈:表示按位操作或逐點操作(pointwise operation),例如向量加和、向量乘積等;
單箭頭:表示信號傳遞(向量傳遞);
合流箭頭:表示兩個信號的連接(向量拼接);
分流箭頭:表示信號被復(fù)制后傳遞到2個不同的地方。
下面我們將分別介紹這些模塊如何在LSTM中作用。
3.2 LSTM的基本思想
LSTM的關(guān)鍵是細(xì)胞狀態(tài)(直譯:cell state),表示為
,用來保存當(dāng)前LSTM的狀態(tài)信息并傳遞到下一時刻的LSTM中,也就是RNN中那根“自循環(huán)”的箭頭。當(dāng)前的LSTM接收來自上一個時刻的細(xì)胞狀態(tài)
,并與當(dāng)前LSTM接收的信號輸入
共同作用產(chǎn)生當(dāng)前LSTM的細(xì)胞狀態(tài)
,具體的作用方式下面將詳細(xì)介紹。
LSTM3-C-line.png
在LSTM中,采用專門設(shè)計的“門”來引入或者去除細(xì)胞狀態(tài)
中的信息。門是一種讓信息選擇性通過的方法。有的門跟信號處理中的濾波器有點類似,允許信號部分通過或者通過時被門加工了;有的門也跟數(shù)字電路中的邏輯門類似,允許信號通過或者不通過。這里所采用的門包含一個
神經(jīng)網(wǎng)絡(luò)層和一個按位的乘法操作,如下圖所示:
LSTM3-gate.png
其中黃色方塊表示
神經(jīng)網(wǎng)絡(luò)層,粉色圓圈表示按位乘法操作。
神經(jīng)網(wǎng)絡(luò)層可以將輸入信號轉(zhuǎn)換為
到
之間的數(shù)值,用來描述有多少量的輸入信號可以通過。
表示“不允許任何量通過”,
表示“允許所有量通過”。
神經(jīng)網(wǎng)絡(luò)層起到類似下圖的
函數(shù)所示的作用:
sigmod_function.jpg
其中,橫軸表示輸入信號,縱軸表示經(jīng)過
以后的輸出信號。
LSTM主要包括三個不同的門結(jié)構(gòu):遺忘門、記憶門和輸出門。這三個門用來控制LSTM的信息保留和傳遞,最終反映到細(xì)胞狀態(tài)
和輸出信號
。如下圖所示:
LSTM_gates.png
圖中標(biāo)示了LSTM中各個門的構(gòu)成情況和相互之間的關(guān)系,其中:
遺忘門由一個
神經(jīng)網(wǎng)絡(luò)層和一個按位乘操作構(gòu)成;
記憶門由輸入門(input gate)與tanh神經(jīng)網(wǎng)絡(luò)層和一個按位乘操作構(gòu)成;
輸出門(output gate)與
函數(shù)(注意:這里不是
神經(jīng)網(wǎng)絡(luò)層)以及按位乘操作共同作用將細(xì)胞狀態(tài)和輸入信號傳遞到輸出端。
3.3 遺忘門
顧名思義,遺忘門的作用就是用來“忘記”信息的。在LSTM的使用過程中,有一些信息不是必要的,因此遺忘門的作用就是用來選擇這些信息并“忘記”它們。遺忘門決定了細(xì)胞狀態(tài)
中的哪些信息將被遺忘。那么遺忘門的工作原理是什么呢?看下面這張圖。
LSTM3-focus-f.png
左邊高亮的結(jié)構(gòu)就是遺忘門了,包含一個
神經(jīng)網(wǎng)絡(luò)層(黃色方框,神經(jīng)網(wǎng)絡(luò)參數(shù)為
),接收
時刻的輸入信號
和
時刻LSTM的上一個輸出信號
,這兩個信號進(jìn)行拼接以后共同輸入到
神經(jīng)網(wǎng)絡(luò)層中,然后輸出信號
,
是一個
到
之間的數(shù)值,并與
相乘來決定
中的哪些信息將被保留,哪些信息將被舍棄。可能看到這里有的初學(xué)者還是不知道具體是什么意思,我們用一個簡單的例子來說明。
假設(shè)
,
,
, 那么遺忘門的輸入信號就是
和
的組合,即
, 然后通過
神經(jīng)網(wǎng)絡(luò)層輸出每一個元素都處于
到
之間的向量
,注意,此時
是一個與
維數(shù)相同的向量,此處為3維。如果看到這里還沒有看懂的讀者,可能會有這樣的疑問:輸入信號明明是6維的向量,為什么
就變成了3維呢?這里可能是將
神經(jīng)網(wǎng)絡(luò)層當(dāng)成了
激活函數(shù)了,兩者不是一個東西,初學(xué)者在這里很容易混淆。下文所提及的
神經(jīng)網(wǎng)絡(luò)層和
神經(jīng)網(wǎng)絡(luò)層而是類似的道理,他們并不是簡單的
激活函數(shù)和
激活函數(shù),在學(xué)習(xí)時要注意區(qū)分。
3.4 記憶門
記憶門的作用與遺忘門相反,它將決定新輸入的信息
和
中哪些信息將被保留。
LSTM3-focus-i.png
如圖所示,記憶門包含2個部分。第一個是包含
神經(jīng)網(wǎng)絡(luò)層(輸入門,神經(jīng)網(wǎng)絡(luò)網(wǎng)絡(luò)參數(shù)為
)和一個
神經(jīng)網(wǎng)絡(luò)層(神經(jīng)網(wǎng)絡(luò)參數(shù)為
)。
神經(jīng)網(wǎng)絡(luò)層的作用很明顯,跟遺忘門一樣,它接收
和
作為輸入,然后輸出一個
到
之間的數(shù)值
來決定哪些信息需要被更新;
Tanh神經(jīng)網(wǎng)絡(luò)層的作用是將輸入的
和
整合,然后通過一個
神經(jīng)網(wǎng)絡(luò)層來創(chuàng)建一個新的狀態(tài)候選向量
,
的值范圍在
到
之間。
記憶門的輸出由上述兩個神經(jīng)網(wǎng)絡(luò)層的輸出決定,
與
相乘來選擇哪些信息將被新加入到
時刻的細(xì)胞狀態(tài)
中。
3.5 更新細(xì)胞狀態(tài)
有了遺忘門和記憶門,我們就可以更新細(xì)胞狀態(tài)
了。
LSTM3-focus-C.png
這里將遺忘門的輸出
與上一時刻的細(xì)胞狀態(tài)
相乘來選擇遺忘和保留一些信息,將記憶門的輸出與從遺忘門選擇后的信息加和得到新的細(xì)胞狀態(tài)
。這就表示
時刻的細(xì)胞狀態(tài)
已經(jīng)包含了此時需要丟棄的
時刻傳遞的信息和
時刻從輸入信號獲取的需要新加入的信息
。
將繼續(xù)傳遞到
時刻的LSTM網(wǎng)絡(luò)中,作為新的細(xì)胞狀態(tài)傳遞下去。
3.6 輸出門
前面已經(jīng)講了LSTM如何來更新細(xì)胞狀態(tài)
, 那么在
時刻我們輸入信號
以后,對應(yīng)的輸出信號該如何計算呢?
LSTM3-focus-o.png
如上面左圖所示,輸出門就是將
時刻傳遞過來并經(jīng)過了前面遺忘門與記憶門選擇后的細(xì)胞狀態(tài)
, 與
時刻的輸出信號
和
時刻的輸入信號
整合到一起作為當(dāng)前時刻的輸出信號。整合的過程如上圖所示,
和
經(jīng)過一個
神經(jīng)網(wǎng)絡(luò)層(神經(jīng)網(wǎng)絡(luò)參數(shù)為
)輸出一個
到
之間的數(shù)值
。
經(jīng)過一個
函數(shù)(注意:這里不是
神經(jīng)網(wǎng)絡(luò)層)到一個在
到
之間的數(shù)值,并與
相乘得到輸出信號
,同時
也作為下一個時刻的輸入信號傳遞到下一階段。
其中,
函數(shù)是激活函數(shù)的一種,函數(shù)圖像為:
tanh.png
至此,基本的LSTM網(wǎng)絡(luò)模型就介紹完了。如果對LSTM模型還沒有理解到的,可以看一下這個視頻,作者是一個外國小哥,英文講解的,有動圖,方便理解。
3.7 LSTM的一些變體
前面已經(jīng)介紹了基本的LSTM網(wǎng)絡(luò)模型,而實際應(yīng)用時,我們常常會采用LSTM的一些變體,雖然差異不大,這里不再做詳細(xì)介紹,有興趣的讀者可以自行了解。
3.7.1 在門上增加窺視孔
LSTM3-var-peepholes.png
這是2000年Gers和Schemidhuber教授提出的一種LSTM變體。圖中,在傳統(tǒng)的LSTM結(jié)構(gòu)基礎(chǔ)上,每個門(遺忘門、記憶門和輸出門)增加了一個“窺視孔”(Peephole),有的學(xué)者在使用時也選擇只對部分門加入窺視孔。
3.7.2 整合遺忘門和輸入門
LSTM3-var-tied.png
與傳統(tǒng)的LSTM不同的是,這個變體不需要分開來確定要被遺忘和記住的信息,采用一個結(jié)構(gòu)搞定。在遺忘門的輸出信號值(
到
之間)上,用
減去該數(shù)值來作為記憶門的狀態(tài)選擇,表示只更新需要被遺忘的那些信息的狀態(tài)。
3.7.3 GRU
改進(jìn)比較大的一個LSTM變體叫Gated Recurrent Unit (GRU),目前應(yīng)用較多。結(jié)構(gòu)圖如下
LSTM3-var-GRU.png
GRU主要包含2個門:重置門和更新門。GRU混合了細(xì)胞狀態(tài)
和隱藏狀態(tài)
為一個新的狀態(tài),使用
來表示。 該模型比傳統(tǒng)的標(biāo)準(zhǔn)LSTM模型簡單。
4. 基于Pytorch的LSTM代碼實現(xiàn)
Pytorch是Python的一個機(jī)器學(xué)習(xí)包,與Tensorflow類似,Pytorch非常適合用來構(gòu)建神經(jīng)網(wǎng)絡(luò)模型,并且已經(jīng)提供了一些常用的神經(jīng)網(wǎng)絡(luò)模型包,用戶可以直接調(diào)用。下面我們就用一個簡單的小例子來說明如何使用Pytorch來構(gòu)建LSTM模型。
我們使用正弦函數(shù)和余弦函數(shù)來構(gòu)造時間序列,而正余弦函數(shù)之間是成導(dǎo)數(shù)關(guān)系,所以我們可以構(gòu)造模型來學(xué)習(xí)正弦函數(shù)與余弦函數(shù)之間的映射關(guān)系,通過輸入正弦函數(shù)的值來預(yù)測對應(yīng)的余弦函數(shù)的值。
正弦函數(shù)和余弦函數(shù)對應(yīng)關(guān)系圖如下圖所示:
demo_sine_cosine.png
可以看到,每一個函數(shù)曲線上,每一個正弦函數(shù)的值都對應(yīng)一個余弦函數(shù)值。但其實如果只關(guān)心正弦函數(shù)的值本身而不考慮當(dāng)前值所在的時間,那么正弦函數(shù)值和余弦函數(shù)值不是一一對應(yīng)關(guān)系。例如,當(dāng)
和
時,
,但在這兩個不同的時刻,
的值卻不一樣,也就是說如果不考慮時間,同一個正弦函數(shù)值可能對應(yīng)了不同的幾個余弦函數(shù)值。對于傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)來說,它僅僅基于當(dāng)前的輸入來預(yù)測輸出,對于這種同一個輸入可能對應(yīng)多個輸出的情況不再適用。
我們?nèi)≌液瘮?shù)的值作為LSTM的輸入,來預(yù)測余弦函數(shù)的值。基于Pytorch來構(gòu)建LSTM模型,采用1個輸入神經(jīng)元,1個輸出神經(jīng)元,16個隱藏神經(jīng)元作為LSTM網(wǎng)絡(luò)的構(gòu)成參數(shù),平均絕對誤差(LMSE)作為損失誤差,使用Adam優(yōu)化算法來訓(xùn)練LSTM神經(jīng)網(wǎng)絡(luò)。基于Anaconda和Python3.6的完整代碼如下:
# -*- coding:UTF-8 -*-
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
# Define LSTM Neural Networks
class LstmRNN(nn.Module):
"""
Parameters:
- input_size: feature size
- hidden_size: number of hidden units
- output_size: number of output
- num_layers: layers of LSTM to stack
"""
def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers) # utilize the LSTM model in torch.nn
self.forwardCalculation = nn.Linear(hidden_size, output_size)
def forward(self, _x):
x, _ = self.lstm(_x) # _x is input, size (seq_len, batch, input_size)
s, b, h = x.shape # x is output, size (seq_len, batch, hidden_size)
x = x.view(s*b, h)
x = self.forwardCalculation(x)
x = x.view(s, b, -1)
return x
if __name__ == '__main__':
# create database
data_len = 200
t = np.linspace(0, 12*np.pi, data_len)
sin_t = np.sin(t)
cos_t = np.cos(t)
dataset = np.zeros((data_len, 2))
dataset[:,0] = sin_t
dataset[:,1] = cos_t
dataset = dataset.astype('float32')
# plot part of the original dataset
plt.figure()
plt.plot(t[0:60], dataset[0:60,0], label='sin(t)')
plt.plot(t[0:60], dataset[0:60,1], label = 'cos(t)')
plt.plot([2.5, 2.5], [-1.3, 0.55], 'r--', label='t = 2.5') # t = 2.5
plt.plot([6.8, 6.8], [-1.3, 0.85], 'm--', label='t = 6.8') # t = 6.8
plt.xlabel('t')
plt.ylim(-1.2, 1.2)
plt.ylabel('sin(t) and cos(t)')
plt.legend(loc='upper right')
# choose dataset for training and testing
train_data_ratio = 0.5 # Choose 80% of the data for testing
train_data_len = int(data_len*train_data_ratio)
train_x = dataset[:train_data_len, 0]
train_y = dataset[:train_data_len, 1]
INPUT_FEATURES_NUM = 1
OUTPUT_FEATURES_NUM = 1
t_for_training = t[:train_data_len]
# test_x = train_x
# test_y = train_y
test_x = dataset[train_data_len:, 0]
test_y = dataset[train_data_len:, 1]
t_for_testing = t[train_data_len:]
# ----------------- train -------------------
train_x_tensor = train_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5
train_y_tensor = train_y.reshape(-1, 5, OUTPUT_FEATURES_NUM) # set batch size to 5
# transfer data to pytorch tensor
train_x_tensor = torch.from_numpy(train_x_tensor)
train_y_tensor = torch.from_numpy(train_y_tensor)
# test_x_tensor = torch.from_numpy(test_x)
lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=1) # 16 hidden units
print('LSTM model:', lstm_model)
print('model.parameters:', lstm_model.parameters)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-2)
max_epochs = 10000
for epoch in range(max_epochs):
output = lstm_model(train_x_tensor)
loss = loss_function(output, train_y_tensor)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if loss.item() < 1e-4:
print('Epoch [{}/{}], Loss: {:.5f}'.format(epoch+1, max_epochs, loss.item()))
print("The loss value is reached")
break
elif (epoch+1) % 100 == 0:
print('Epoch: [{}/{}], Loss:{:.5f}'.format(epoch+1, max_epochs, loss.item()))
# prediction on training dataset
predictive_y_for_training = lstm_model(train_x_tensor)
predictive_y_for_training = predictive_y_for_training.view(-1, OUTPUT_FEATURES_NUM).data.numpy()
# torch.save(lstm_model.state_dict(), 'model_params.pkl') # save model parameters to files
# ----------------- test -------------------
# lstm_model.load_state_dict(torch.load('model_params.pkl')) # load model parameters from files
lstm_model = lstm_model.eval() # switch to testing model
# prediction on test dataset
test_x_tensor = test_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5, the same value with the training set
test_x_tensor = torch.from_numpy(test_x_tensor)
predictive_y_for_testing = lstm_model(test_x_tensor)
predictive_y_for_testing = predictive_y_for_testing.view(-1, OUTPUT_FEATURES_NUM).data.numpy()
# ----------------- plot -------------------
plt.figure()
plt.plot(t_for_training, train_x, 'g', label='sin_trn')
plt.plot(t_for_training, train_y, 'b', label='ref_cos_trn')
plt.plot(t_for_training, predictive_y_for_training, 'y--', label='pre_cos_trn')
plt.plot(t_for_testing, test_x, 'c', label='sin_tst')
plt.plot(t_for_testing, test_y, 'k', label='ref_cos_tst')
plt.plot(t_for_testing, predictive_y_for_testing, 'm--', label='pre_cos_tst')
plt.plot([t[train_data_len], t[train_data_len]], [-1.2, 4.0], 'r--', label='separation line') # separation line
plt.xlabel('t')
plt.ylabel('sin(t) and cos(t)')
plt.xlim(t[0], t[-1])
plt.ylim(-1.2, 4)
plt.legend(loc='upper right')
plt.text(14, 2, "train", size = 15, alpha = 1.0)
plt.text(20, 2, "test", size = 15, alpha = 1.0)
plt.show()
訓(xùn)練的過程如下:
LSTM_training.png
該模型在訓(xùn)練集和測試集上的結(jié)果如下:
demo_LSTM.png
圖中,紅色虛線的左邊表示該模型在訓(xùn)練數(shù)據(jù)集上的表現(xiàn),右邊表示該模型在測試數(shù)據(jù)集上的表現(xiàn)。可以看到,使用LSTM構(gòu)建訓(xùn)練模型,我們可以僅僅使用正弦函數(shù)在
時刻的值作為輸入來準(zhǔn)確預(yù)測
時刻的余弦函數(shù)值,不用額外添加當(dāng)前的時間信息、速度信息等。
5. 參考鏈接
總結(jié)
以上是生活随笔為你收集整理的控制论python_[干货]深入浅出LSTM及其Python代码实现的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 微信小程序没登录跳到登录页怎么做_微信小
- 下一篇: r型聚类典型指标_六种GAN评估指标的综