TensorFlow RNN tutorial解读
生活随笔
收集整理的這篇文章主要介紹了
TensorFlow RNN tutorial解读
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
github鏈接
和其他代碼比起來,這個代碼的結構很不科學,只有一個主文件,model和train沒有分開……
參考鏈接:
tensorflow筆記:多層LSTM代碼分析
代碼分析:
下圖左邊是rolled版本,這種結構的反向傳播計算很困難。因此采用右邊這種結構,有num_step個x輸入。每次訓練數據輸入格式為[batch_size, num_steps](類比seq2seq中的encoder_input的shape [batch_size, encoder_size])。如下圖:
seq2seq是分為了兩個部分,encoder和decoder部分。在RNN中,只有encode,即輸入x,輸出o,不需要decoder_input部分。在本例中,輸入是[batch_size, num_steps]個的單詞預測概率(one-hot形式)。和[batch_size, num_steps]個target(數字形式)作比較。計算loss的函數是tf.contrib.seq2seq.sequence_loss
定義了模型之后,第二個graph負責給model feed并統計結果。
要fetch的內容有:model.cost,model.final_state, model.eval_op(train_op)
vals = session.run(fetches, feed_dict)
代碼結構:
reader.py中的兩個主要函數:
- ptb_raw_data把三個txt中的單詞都轉化成唯一的id,只保留最常見的10000個單詞。
- ptb_producer定義了input和target。格式為:batch_size*num_steps的二維矩陣。target是input右移1位后的結果。即通過前一個單詞預測后一個單詞。
ptb_word_lm.py中
首先定義了PTBModel結構,__init__函數定義了self.config, inputs, output, state, self.logits, self.cost, self.final_state, (self.learning_rate, self._train_op, self._new_lr, self._lr_update)這些是train才有的。
最后在main函數里
- 第一個graph里,分別構建learn\valid\test PTBmodel,保存在metagraph中。
- 第二個graph里,導入metagraph,對于三個模型分別run_epoch。
轉載于:https://www.cnblogs.com/yingtaomj/p/7777222.html
總結
以上是生活随笔為你收集整理的TensorFlow RNN tutorial解读的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python学习笔记之函数式编程
- 下一篇: TLS,SSL,HTTPS with P