TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类
生活随笔
收集整理的這篇文章主要介紹了
TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
TF之LSTM:利用多層LSTM算法對MNIST手寫數字識別數據集進行多分類
?
?
目錄
設計思路
實現代碼
?
?
設計思路
更新……
?
?
實現代碼
# -*- coding:utf-8 -*- import tensorflow as tf import numpy as np from tensorflow.contrib import rnn from tensorflow.examples.tutorials.mnist import input_data#根據電腦情況設置 GPU config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config)# 1、定義數據集 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) print(mnist.train.images.shape)#2、定義模型超參數 lr = 1e-3 # batch_size = 128 batch_size = tf.placeholder(tf.int32) #采用占位符的方式,因為在訓練和測試的時候要用不同的batch_size。注意類型必須為 tf.int32 input_size = 28 # 每個時刻的輸入特征是28維的,就是每個時刻輸入一行,一行有 28 個像素 timestep_size = 28 # 時序持續長度為28,即每做一次預測,需要先輸入28行 hidden_size = 256 # 每個隱含層的節點數 layer_num = 2 # LSTM layer 的層數 class_num = 10 # 最后輸出分類類別數量,如果是回歸預測的話應該是 1_X = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, class_num]) keep_prob = tf.placeholder(tf.float32)#3、LSTM模型的搭建、訓練、測試 #3.1、LSTM模型的搭建 X = tf.reshape(_X, [-1, 28, 28]) #RNN 的輸入shape = (batch_size, timestep_size, input_size),把784個點的字符信息還原成 28 * 28 的圖片 lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True) #定義一層 LSTM_cell,只需要說明 hidden_size, 它會自動匹配輸入的 X 的維度 lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) #添加 dropout layer, 一般只設置 output_keep_prob mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True) #調用 MultiRNNCell來實現多層 LSTM init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32) #用全零來初始化state#3.2、LSTM模型的運行:構建好的網絡運行起來 #T1、調用 dynamic_rnn()法 # ** 當 time_major==False 時, outputs.shape = [batch_size, timestep_size, hidden_size],所以,可以取 h_state = outputs[:, -1, :] 作為最后輸出 # ** state.shape = [layer_num, 2, batch_size, hidden_size],或者,可以取 h_state = state[-1][1] 作為最后輸出,最后輸出維度是 [batch_size, hidden_size] # outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False) # h_state = outputs[:, -1, :] # 或者 h_state = state[-1][1]#T2、自定義LSTM迭代按時間步展開計算:為了更好的理解 LSTM 工作原理把T1的函數自己來實現 #(1)、可以采用RNNCell的 __call__()函數,來實現LSTM按時間步迭代。 outputs = list() state = init_state with tf.variable_scope('RNN'):for timestep in range(timestep_size):if timestep > 0:tf.get_variable_scope().reuse_variables()(cell_output, state) = mlstm_cell(X[:, timestep, :], state) # 這里的state保存了每一層 LSTM 的狀態outputs.append(cell_output) h_state = outputs[-1]#3.3、LSTM模型的訓練 # 定義 softmax 的連接權重矩陣和偏置:上面 LSTM 部分的輸出會是一個 [hidden_size] 的tensor,我們要分類的話,還需要接一個 softmax 層 # out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights') # out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias') W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32) bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32) y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)#定義損失和評估函數 cross_entropy = -tf.reduce_mean(y * tf.log(y_pre)) train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))sess.run(tf.global_variables_initializer()) for i in range(2000):_batch_size = 128batch = mnist.train.next_batch(_batch_size)if (i+1)%200 == 0:train_accuracy = sess.run(accuracy, feed_dict={_X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})# 已經迭代完成的 epoch 數: mnist.train.epochs_completedprint("Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy))sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})# 計算測試數據的準確率 print("test accuracy %g"% sess.run(accuracy, feed_dict={_X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]}))?
參考文章:https://www.cnblogs.com/mfryf/p/7903958.html
?
?
?
總結
以上是生活随笔為你收集整理的TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: DL之模型调参:深度学习算法模型优化参数
- 下一篇: Paper:《Generating Se