tensorflow保存模型和加载模型的方法(Python和Android)
tensorflow保存模型和加載模型的方法(Python和Android)
一、tensorflow保存模型的幾種方法:
(1)?tf.train.saver()保存模型
? ? ?使用 tf.train.saver()保存模型,該方法保存模型文件的時候會產生多個文件,會把計算圖的結構和圖上參數取值分成了不同的文件存儲。這種方法是在TensorFlow中是最常用的保存方式。
? ? 例如:
import tensorflow as tf # 聲明兩個變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部變量 saver = tf.train.Saver() # 聲明tf.train.Saver類用于保存模型 with tf.Session() as sess:sess.run(init_op)print("v1:", sess.run(v1)) # 打印v1、v2的值一會讀取之后對比print("v2:", sess.run(v2))saver_path = saver.save(sess, "save/model.ckpt") # 將模型保存到save/model.ckpt文件print("Model saved in file:", saver_path)? ? 運行后,會在save目錄下保存了四個文件:
? ? 其中
- checkpoint是檢查點文件,文件保存了一個目錄下所有的模型文件列表;
- model.ckpt.meta文件保存了TensorFlow計算圖的結構,可以理解為神經網絡的網絡結構,該文件可以被 tf.train.import_meta_graph 加載到當前默認的圖來使用。
- ckpt.data : 保存模型中每個變量的取值
參考資料:
https://blog.csdn.net/michael_yt/article/details/74737489
https://blog.csdn.net/lwplwf/article/details/62419087
(2)tf.train.write_graph()
? ? 使用 tf.train.write_graph()保存模型,該方法只是保存了模型的結構,并不保存訓練完畢的參數值。
(3)convert_variables_to_constants固化模型結構
? ? 很多時候,我們需要將TensorFlow的模型導出為單個文件(同時包含模型結構的定義與權重),方便在其他地方使用(如在Android中部署網絡)。利用tf.train.write_graph()默認情況下只導出了網絡的定義(沒有權重),而利用tf.train.Saver().save()導出的文件graph_def與權重是分離的,因此需要采用別的方法。?我們知道,graph_def文件中沒有包含網絡中的Variable值(通常情況存儲了權重),但是卻包含了constant值,所以如果我們能把Variable轉換為constant,即可達到使用一個文件同時存儲網絡架構與權重的目標。
? ? TensoFlow為我們提供了convert_variables_to_constants()方法,該方法可以固化模型結構,將計算圖中的變量取值以常量的形式保存。而且保存的模型可以移植到Android平臺。
????參考資料:
????【1】https://blog.csdn.net/sinat_29957455/article/details/78511119
? ? 【2】這里主要實現第三種方法,因為該方法保存的模型可以移植到Android平臺運行。以下Python代碼,都共享在
? ? ??Github:https://github.com/PanJinquan/tensorflow-learning-tutorials/tree/master/MNIST-Demo;
? ? 【3】移植Android的詳細過程可參考本人的另一篇博客資料《將tensorflow MNIST訓練模型移植到Android》:
? ? ? ?https://blog.csdn.net/guyuealian/article/details/79672257
二、訓練和保存模型
? ? 以MNIST手寫數字識別為例,這里首先使用Python版的TensorFlow實現SoftMax Regression分類器,并將訓練好的模型的網絡拓撲結構和參數保存為pb文件,其中convert_variables_to_constants函數,會將計算圖中的變量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119
#coding=utf-8 from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf from tensorflow.python.framework import graph_util print('tensortflow:{0}'.format(tf.__version__))mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)#create model with tf.name_scope('input'):x = tf.placeholder(tf.float32,[None,784],name='x_input')#輸入節點名:x_inputy_ = tf.placeholder(tf.float32,[None,10],name='y_input') with tf.name_scope('layer'):with tf.name_scope('W'):#tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]W = tf.Variable(tf.zeros([784,10]),name='Weights')with tf.name_scope('b'):b = tf.Variable(tf.zeros([10]),name='biases')with tf.name_scope('W_p_b'):Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')y = tf.nn.softmax(Wx_plus_b, name='final_result')# 定義損失函數和優化方法 with tf.name_scope('loss'):loss = -tf.reduce_sum(y_ * tf.log(y)) with tf.name_scope('train_step'):train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)print(train_step) # 初始化 sess = tf.InteractiveSession() init = tf.global_variables_initializer() sess.run(init) # 訓練 for step in range(100):batch_xs,batch_ys =mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})# variables = tf.all_variables()# print(len(variables))# print(sess.run(b))# 測試模型準確率 pre_num=tf.argmax(y,1,output_type='int32',name="output")#輸出節點名:output correct_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32')) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}) print('測試正確率:{0}'.format(a))# 保存訓練好的模型 #形參output_node_names用于指定輸出的節點名稱,output_node_names=['output']對應pre_num=tf.argmax(y,1,name="output"), output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output']) with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表寫文件,b代表將數據以二進制方式寫入文件。f.write(output_graph_def.SerializeToString()) sess.close()# 注: # convert_variables_to_constants函數,會將計算圖中的變量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119?
三、加載和測試
批量測試:
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_data from PIL import Image import matplotlib import matplotlib.pyplot as plt#模型路徑 model_path = 'model/mnist.pb' #測試數據 mnist = input_data.read_data_sets("Mnist_data/", one_hot=True) x_test = mnist.test.images x_labels = mnist.test.labels;with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")# 【1】下面是進行批量測試----------------------------------------------------------pre_num = sess.run(output, feed_dict={input_x: x_test})#利用訓練好的模型預測結果#結果批量測試的準確率correct_prediction = tf.equal(pre_num, tf.argmax(x_labels, 1,output_type='int32'))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))acc = sess.run(accuracy, feed_dict={input_x: x_test})# a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})print('測試正確率:{0}'.format(acc))#【2】下面是進行單張圖片的測試-----------------------------------------------------testImage=x_test[0];test_input = testImage.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用訓練好的模型預測結果print('模型預測結果為:',pre_num)#顯示測試的圖片testImage = testImage.reshape(28, 28)testImage=np.array(testImage * 255, dtype="int32")fig = plt.figure(), plt.imshow(testImage, cmap='binary') # 顯示圖片plt.title("prediction result:"+str(pre_num))plt.show()#保存測定的圖片testImage = Image.fromarray(testImage)testImage = testImage.convert('L')testImage.save("data/test_image.jpg")# matplotlib.image.imsave('data/name.jpg', im)sess.close()?
單個樣本測試:
?
import tensorflow as tf import numpy as np from PIL import Image import matplotlib.pyplot as plt#模型路徑 model_path = 'model/mnist.pb' #測試圖片 testImage = Image.open("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#對圖片進行測試testImage=testImage.convert('L')testImage = testImage.resize((28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用訓練好的模型預測結果print('模型預測結果為:',pre_num)#顯示測試的圖片# testImage = test_x.reshape(28, 28)fig = plt.figure(), plt.imshow(testImage,cmap='binary') # 顯示圖片plt.title("prediction result:"+str(pre_num))plt.show()讀取圖片進行測試:
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import cv2 as cv #模型路徑 model_path = 'model/mnist.pb' #測試圖片 testImage = cv.imread("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#對圖片進行測試testImage=cv.cvtColor(testImage, cv.COLOR_BGR2GRAY)testImage=cv.resize(testImage,dsize=(28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用訓練好的模型預測結果print('模型預測結果為:',pre_num)# cv.imshow("image",testImage)# cv.waitKey(0)#顯示測試的圖片fig = plt.figure(), plt.imshow(testImage,cmap='binary') # 顯示圖片plt.title("prediction result:"+str(pre_num))plt.show()源碼Github:https://github.com/PanJinquan/MNIST-TensorFlow-Python
上面TensorFlow保存訓練好的模型,可以移植到Android,詳細過程可以參考另一篇博客資料《將tensorflow MNIST訓練模型移植到Android》:https://blog.csdn.net/guyuealian/article/details/79672257
?
?
總結
以上是生活随笔為你收集整理的tensorflow保存模型和加载模型的方法(Python和Android)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 将tensorflow训练好的模型移植到
- 下一篇: 随时更新———个人喜欢的关于模式识别、机