用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)進行MNIST 手寫數字辨識
循環神經網絡RNN相比傳統的神經網絡在處理序列化數據時更有優勢,因為RNN能夠將加入上(下)文信息進行考慮。一個簡單的RNN如下圖所示:
將這個循環展開得到下圖:
上一時刻的狀態會傳遞到下一時刻。這種鏈式特性決定了RNN能夠很好的處理序列化的數據,RNN 在語音識別,語言建模,翻譯,圖片描述等問題上已經取得了很到的結果。
根據輸入、輸出的不同和是否有延遲等一些情況,RNN在應用中有如下一些形態:
RNN存在的問題
RNN能夠把狀態傳遞到下一時刻,好像對一部分信息有記憶能力一樣,如下圖:
h3
的值可能會由x1,x2的值來決定。
但是,對于一些復雜場景
由于距離太遠,中間間隔了太多狀態,x1,x2對ht+1
的值幾乎起不到任何作用。(梯度消失和梯度爆炸)
LSTM(Long Short Term Memory)
由于RNN不能很好地處理這種問題,于是出現了LSTM(Long Short Term Memory)一種加強版的RNN(LSTM可以改善梯度消失問題)。簡單來說就是原始RNN沒有長期的記憶能力,于是就給RNN加上了一些記憶控制器,實現對某些信息能夠較長期的記憶,而對某些信息只有短期記憶能力。
如上圖所示,LSTM中存在Forget Gate,Input Gate,Output Gate來控制信息的流動程度。
RNN:
LSTN:
加號圓圈表示線性相加,乘號圓圈表示用gate來過濾信息。
Understanding LSTM中對LSTM有非常詳細的介紹。(對應的中文翻譯)
LSTM MNIST手寫數字辨識
實際上,圖片文字識別這類任務用CNN來做效果更好,但是這里想要強行用LSTM來做一波。
MNIST_data中每一個image的大小是28*28,以行順序作為序列輸入,即第一行的28個像素作為$x_{0}
,第二行為
x_1,...,第28行的28個像素作為
x_28$輸入,一個網絡結構總共的輸入是28個維度為28的向量,輸出值是10維的向量,表示的是0-9個數字的概率值。這是一個many to one的RNN結構。
下面直接上代碼:
這里outputs,final_state = tf.nn.dynamic_rnn(...).
final_state包含兩個量,第一個為c保存了每個LSTM任務最后一個cell中每個神經元的狀態值,第二個量h保存了每個LSTM任務最后一個cell中每個神經元的輸出值,所以c和h的維度都是[BATCH_SIZE,NUM_UNITS]。
outputs的維度是[BATCH_SIZE,TIME_STEP,NUM_UNITS],保存了每個step中cell的輸出值h。
由于這里是一個many to one的任務,只需要最后一個step的輸出outputs[:, -1, :],output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES) 通過一個全連接層將輸出限制為N_CLASSES。
訓練過程輸出:
train loss: 2.2990 | test accuracy: 0.13 train loss: 0.1347 | test accuracy: 0.96 train loss: 0.0620 | test accuracy: 0.97 train loss: 0.0788 | test accuracy: 0.98 train loss: 0.0160 | test accuracy: 0.98 train loss: 0.0084 | test accuracy: 0.99 train loss: 0.0436 | test accuracy: 0.99 train loss: 0.0104 | test accuracy: 0.98 train loss: 0.0736 | test accuracy: 0.99 train loss: 0.0154 | test accuracy: 0.98 train loss: 0.0407 | test accuracy: 0.98 train loss: 0.0109 | test accuracy: 0.98 train loss: 0.0722 | test accuracy: 0.98 train loss: 0.1133 | test accuracy: 0.98 train loss: 0.0072 | test accuracy: 0.99 train loss: 0.0352 | test accuracy: 0.98可以看到,雖然RNN是擅長處理序列類的任務,在MNIST手寫數字圖片辨識這個任務上,RNN同樣可以取得很高的正確率。
參考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://yjango.gitbooks.io/superorganism/content/lstmgru.html
參考代碼
https://www.cnblogs.com/sandy-t/p/6930608.html
有些人,一輩子都沒有得到過自己想要的,因為他們總是半途而廢
總結
以上是生活随笔為你收集整理的用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python 格式化输出%和format
- 下一篇: cad和python哪个好学_对纯外行人