简单的线性回归实现模型的存储和读取
生活随笔
收集整理的這篇文章主要介紹了
简单的线性回归实现模型的存储和读取
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
和這篇文章對比https://blog.csdn.net/fanzonghao/article/details/81023730
不希望重復定義圖上的運算,也就是在模型恢復過程中,不想sess.run(init)首先看路徑
lineRegulation_model.py定義線性回歸類:
import tensorflow as tf """ 類定義一些公共量,方便模型載入用 """ class LineRegModel:def __init__(self):with tf.variable_scope('var'):self.a_val=tf.Variable(tf.random_normal(shape=[1]),name='a_val')self.b_val = tf.Variable(tf.random_normal(shape=[1]),name='b_val')self.x_input=tf.placeholder(dtype=tf.float32,name='input_placeholder')self.y_label = tf.placeholder(dtype=tf.float32,name='result_placeholder')self.y_output = tf.add(tf.multiply(self.x_input,self.a_val),self.b_val,name='output')self.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))def get_saver(self):return tf.train.Saver()def get_op(self):return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)model_train.py定義模型訓練過程
import tensorflow as tf import numpy as np from save_and_restore2 import global_variable from save_and_restore2 import lineRegulation_model as model import os if not os.path.exists('./model'):os.makedirs('./model') """ 訓練模型 """ train_x=np.random.rand(5) train_y=train_x*5+3 model=model.LineRegModel()#類要加括號 a_val=model.a_val b_val=model.b_val x_input=model.x_input y_label=model.y_label y_output=model.y_output loss=model.loss optimizer=model.get_op() saver=model.get_saver() if __name__ == '__main__':init=tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)flag=Trueepoch=0while flag:epoch+=1cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})if cost<1e-6:flag=Falseprint('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))print('epoch={}'.format(epoch))print(a_val)# print(a_val.op)saver.save(sess,global_variable.save_path)print('model save finish')print(a_val)的形式
print(a_val.op)的形式
model_restore.py恢復模型 ,利用恢復圖在恢復權重的方式,可實現更細節的模型恢復
import tensorflow as tf from save_and_restore import global_variable,lineRegulation_model as model """ 恢復模型圖文件 """ saver=tf.train.import_meta_graph('./model/weight.meta') #讀取placeholder和最終的輸出結果 graph=tf.get_default_graph() a_val=graph.get_tensor_by_name('var/a_val:0') b_val=graph.get_tensor_by_name('var/b_val:0')input_placeholder=graph.get_tensor_by_name('input_placeholder:0') labels_placeholder=graph.get_tensor_by_name('result_placeholder:0') y_output=graph.get_tensor_by_name('output:0')with tf.Session() as sess:#具體權重的恢復saver.restore(sess,'./model/weight')result=sess.run(y_output,feed_dict={input_placeholder:[1]})print(result)print(sess.run(a_val))print(sess.run(b_val))?
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的简单的线性回归实现模型的存储和读取的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Animation Property
- 下一篇: 什么是图像