tensorflow学习(2.网络模型的存储以及提取)
生活随笔
收集整理的這篇文章主要介紹了
tensorflow学习(2.网络模型的存储以及提取)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
第一篇學習了CNN網絡的構建以及代碼的基礎結構,第二篇則是實際項目過程中需要的網絡模型的存儲
先放上存儲的代碼:
#tf可以認為是全局變量,從該變量為類,從中取input_data變量 import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf import sys #讀取數據集 mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) """ #softmax方法進行訓練 #這里是變量的占位符,一般是輸入輸出使用該部分 x=tf.placeholder(tf.float32,[None,784]) y_=tf.placeholder("float",[None,10])#定義參數變量 W=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) y=tf.nn.softmax(tf.matmul(x,W)+b)#評價函數 cross_entropy=-tf.reduce_sum(y_*tf.log(y)) train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#啟動模型,Session建立這樣一個對象,然后指定某種操作,并實際進行該步 init=tf.initialize_all_variables() sess=tf.Session() sess.run(init)#數據讀取部分 for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#run第一個參數是fetch,可以是tensor也可以是Operation,第二個feed_dict是替換tensor的值sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})print(batch_xs,batch_ys,i)correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))"""#這里用CNN方法進行訓練 #函數定義部分 def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#隨機權重賦值,不過truncated_normal代表如果是2倍標準差之外的結果重新選取該值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置項return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示輸出補邊,這里輸出與輸入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范圍的大小,stride為掃描步長# 這里是變量的占位符,一般是輸入輸出使用該部分 x=tf.placeholder(tf.float32,[None,784]) y_=tf.placeholder("float",[None,10]) x_image=tf.reshape(x,[-1,28,28,1])#-1表示自動計算該維度 #建立第一層 W_conv1=weight_variable([5,5,1,32]) b_conv1=bias_variable([32]) h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1) h_pool1=max_pool_2x2(h_conv1) #第二層 W_conv2=weight_variable([5,5,32,64]) b_conv2=bias_variable([64]) h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2) h_pool2=max_pool_2x2(h_conv2)#第三層,而且這里是全連接層 W_fc1=weight_variable([7*7*64,1024]) b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64]) h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1) #dropout,注意這里也是有一個輸入參數的,和x以及y一樣 keep_prob=tf.placeholder(tf.float32) h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10]) b_fc2=bias_variable([10]) y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 評價函數 cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv)) train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 啟動模型,Session建立這樣一個對象,然后指定某種操作,并實際進行該步 init=tf.initialize_all_variables() sess=tf.Session() sess.run(init)#數據讀取部分 for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#這里貌似是代表讀取50張圖像數據#run第一個參數是fetch,可以是tensor也可以是Operation,第二個feed_dict是替換tensor的值'''if i % 10 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})print("step:%d,accuracy:%g" % (i, train_accuracy))'''sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})#sess.run第一個參數是想要運行的位置,一般有train,accuracy,initdeng#第二個參數feed_dict,一般是輸入參數,該代碼里有x,y以及drop的參數if i%20==0 :print(i)print("train accuracy:%g"%sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})) print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))#保存模型 model_path='MNIST_model/simple_mnist.ckpt' saver=tf.train.Saver() saver_path=saver.save(sess,model_path) print("model saved in file:", saver_path)?
保存的代碼實際上只有后半部分,前面的代碼是第一篇中講到的。第一篇鏈接:https://blog.csdn.net/qq_26499769/article/details/82896046
運行代碼的結果如下:
讀取的代碼如下:
import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf import sys #讀取數據集 mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#前半部分主要以網絡的構建為主 #這里用CNN方法進行訓練 #函數定義部分 def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#隨機權重賦值,不過truncated_normal代表如果是2倍標準差之外的結果重新選取該值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置項return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示輸出補邊,這里輸出與輸入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范圍的大小,stride為掃描步長# 這里是變量的占位符,一般是輸入輸出使用該部分 x=tf.placeholder(tf.float32,[None,784]) y_=tf.placeholder("float",[None,10]) x_image=tf.reshape(x,[-1,28,28,1])#-1表示自動計算該維度 #建立第一層 W_conv1=weight_variable([5,5,1,32]) b_conv1=bias_variable([32]) h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1) h_pool1=max_pool_2x2(h_conv1) #第二層 W_conv2=weight_variable([5,5,32,64]) b_conv2=bias_variable([64]) h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2) h_pool2=max_pool_2x2(h_conv2)#第三層,而且這里是全連接層 W_fc1=weight_variable([7*7*64,1024]) b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64]) h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1) #dropout,注意這里也是有一個輸入參數的,和x以及y一樣 keep_prob=tf.placeholder(tf.float32) h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10]) b_fc2=bias_variable([10]) y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 評價函數 cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv)) train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 啟動模型,Session建立這樣一個對象,然后指定某種操作,并實際進行該步sess=tf.Session()#后半部分,進行參數的下載 #模型下載,(新人,可能理解錯誤,網絡還是需要先定義好,然后進行參數的下載,對于自己的網絡這樣的方法沒有問題,但是他人的網絡在不知道具體的網絡時,沒辦法通過下載去復現網絡模型) saver=tf.train.Saver() saver.restore(sess,"MNIST_model/simple_mnist.ckpt") print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))讀取代碼也只有最后的一部分,運行結果如下:
這是之前CNN的訓練結果,可以看到完整的下載了參數。
有一個說明的很好的教程:https://www.bilibili.com/video/av16001891/?p=29
總結
以上是生活随笔為你收集整理的tensorflow学习(2.网络模型的存储以及提取)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow学习(1.CNN简单
- 下一篇: tensorflow学习(3.tenso