Tensorflow Lite之编译生成tflite文件
這是tensorflow生成的各種模型文件:
GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions.
CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted.
FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.
SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model.
TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs.
其中關(guān)注的主要三種文件格式:
.pb文件,保存的是圖模型的計算流程圖,包括圖中的常量,但不保存變量,可通過以下兩個方法獲取:
(1):tf.train.write_graph(sess.graph_def,'','graph.pb',as_text=False) #直接保存圖模型,但沒有圖中變量的值
(2):graph = convert_variables_to_constants(sess, sess.graph_def, ["output_image"])
?????tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
#這樣通過將模型里面的所有變量都變?yōu)槌A?那么就可以直接使用.pb文件做成接口,無需.ckpt文件再次導(dǎo)入變量的值.
?
.ckpt文件,保存的是圖模型中的變量的值,要使用.ckpt文件的話,要重構(gòu)圖的結(jié)構(gòu)和初始化圖中變量.可通過以下方式獲取:
saver=tf.train.Saver()
saver.save(sess,"model.ckpt")
?
.lite文件:里面是包含圖模型的計算流程圖和圖模型中的變量的值,可以直接給android系統(tǒng)或者ios系統(tǒng)的tensorflowLite調(diào)用讀取.
?
接下來是生成.lite文件的方法:
首先是,生成lite文件支持的操作和不支持的操作:(如若圖中有不支持的操作,將在生成.lite文件時會報錯)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tf_ops_compatibility.md
?
生成未量化的.lite文件有兩種方式(工具生成跟代碼生成):
第一種,是通過.pb文件和.ckpt文件,進行圖中的變量的固化,再生成.lite文件.執(zhí)行方法如下
(1).要先安裝tensorflow源碼和bazel方法,編譯tensorflow源碼生成tensorflow
(2).cd到源碼目錄下:
(3).編譯生成freeze_graph跟toco工具.最新版本toco替換成tflite_convert工具
??????bazel build tensorflow/python/tools:freeze_graph?
??????bazel build tensorflow/lite/toco:toco?
?
(4).
bazel-bin/tensorflow/python/tools/freeze_graph \
? --input_graph=./model.pb \
? --input_checkpoint=./model.ckpt \
? --output_graph=./frozen_model.pb \
? --input_binary=true \
? --output_node_names=result
#--output_node_names 對應(yīng)的是輸出tensor的name
(5).
./bazel-bin/tensorflow/contrib/lite/toco/toco?
? --input_file=frozen_model.pb \
? --output_file=model.tflite \
? --input_format=TENSORFLOW_GRAPHDEF \
? --output_format=TFLITE \
? --inference_type=FlOAT \
? --input_shape="1,626,361,3" \
? --input_array=input_image \
? --output_array=result \
? --std_value=127.5 \
? --mean_value=127.5 \
? --default_ranges_min=-1.0 \
? --default_ranges_max=1.0 \
? --allow_custom_ops
# --input_arrays 和 --output_arrays 對應(yīng)的是輸入輸出tensor的name
# 注意--input_shapes 必須確定,不可以填None
# --allow_custom_ops 是允許一些傳統(tǒng)方法
可參考:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/convert/cmdline_examples.md?里面包含了混合輸入等多種方法及方法的更新.
?
第二種.是直接在代碼中通過代碼直接生成.lite文件.
如果途中沒有變量
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
? ?tflite_model = tf.lite.toco_convert(sess.graph_def, [img], [out])
? ?open("converteds_model.tflite", "wb").write(tflite_model)
如果圖中有變量的話,需要將變量固化
frozen_graphdef = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output']) ?#這里 ['output']是輸出tensor的名字
tflite_model = tf.lite.toco_convert(frozen_graphdef, [input], [out]) ? #這里[input], [out]這里分別是輸入tensor或者輸出tensor的集合,是變量實體不是名字
open("model.tflite", "wb").write(tflite_model)
?
生成量化的.lite文件也有兩種,分為工具生成跟代碼生成
第一種方法:
(1).首先要想要生成量化的lite文件,在訓(xùn)練graph過程中,就要先偽量化計算圖
在loss后面增加這段代碼:
loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss
?
tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay()) #它會自動將計算圖偽量化
(2).在生成pb文件的前面增加一段代碼:
tf.contrib.quantize.create_eval_graph() #增加這段代碼
?
eval_graph_file = 'graph.pb'
?
with open(eval_graph_file, 'w') as f:
?
f.write(str(g.as_graph_def()))
(3).使用freeze_graph將ckpt跟pb文件凍結(jié)成一個圖pb文件
bazel build tensorflow/python/tools:freeze_graph && \
?
bazel-bin/tensorflow/python/tools/freeze_graph \
?
--input_graph=./graph.pb \
?
--input_checkpoint=./model.ckpt-200 \
?
--output_graph=./frozen_eval_graph_test.pb \
?
--output_node_names=result
(4).使用toco工具量化凍結(jié)后的pb文件
./bazel-bin/third_party/tensorflow/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco?
? --input_file=frozen_eval_graph_test.pb \
? --output_file=tflite_model.tflite \
? --input_format=TENSORFLOW_GRAPHDEF?
? --output_format=TFLITE \
? --inference_type=QUANTIZED_UINT8 \
? --input_shape="1,626,361,3" \
? --input_array=input_image \
? --output_array=result \
? --std_value=127.5 --mean_value=127.5 --default_ranges_min=-1.0 --default_ranges_max=1.0
(5).使用python腳本測試調(diào)用生成的lite文件.
import numpy as np
?
import tensorflow as tf
?
import scipy
?
# Load TFLite model and allocate tensors.
?
interpreter = tf.contrib.lite.Interpreter(model_path="tflite_model.tflite")
?
interpreter.allocate_tensors()
?
?
?
# Get input and output tensors.
?
input_details = interpreter.get_input_details()
?
output_details = interpreter.get_output_details()
?
?
?
image=scipy.misc.imread("test.jpg")
?
image_=np.array([image.astype('uint8')])
?
print(image_.shape)
?
print(type(image_))
?
print(input_details)
?
interpreter.set_tensor(input_details[0]['index'], image_)
?
?
?
interpreter.invoke()
?
output_data = interpreter.get_tensor(output_details[0]['index'])
?
scipy.misc.imsave('res.jpg',output_data)
注意事項:
(1).調(diào)用toco時:--inference_type參數(shù)當(dāng)前只支持QUANTIZED_UINT8跟FLOAT,當(dāng)為FLOAT時,生成的lite文件是未量化的,只有在設(shè)置為QUANTIZED_UINT8,生成的lite文件才是量化文件,大小約為未量化的1/4.
(2).生成的量化的lite文件,輸入的Tensor數(shù)據(jù)類型Type必須為uint8.不然會出現(xiàn)傳入類型錯誤.未量化的lite文件可傳入FLOAT類型的Tensor.
(3).調(diào)用toco時.--default_ranges_min= --default_ranges_max=必須傳入
(4).當(dāng)前有些操作不支持量化,當(dāng)出現(xiàn)不支持量化操作的時候,調(diào)用toco工具的時候會出現(xiàn)這個報錯,這時候要不去掉這個操作又不找別的操作代替.
F tensorflow/contrib/lite/toco/graph_transformations/quantize.cc:474] Unimplemented: this graph contains an operator of type Cast for which the quantized form is not yet implemented. Sorry, and patches welcome (that's a relatively fun patch to write, mostly providing the actual quantized arithmetic code for this op).
?
生成量化的.lite文件第二種方式(親測有用):
可以參考google的一個官方例子:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands
主要參考里面的train.py代碼跟freeze.py代碼,分別用于訓(xùn)練跟生成固化的pb文件.
train.py文件的代碼順序如下:
(1),對輸入數(shù)據(jù)進行量化:fingerprint_input = tf.fake_quant_with_min_max_args(input_placeholder, fingerprint_min, fingerprint_max) 我自己訓(xùn)練自己的圖時沒使用這一步.
(2).定義loss
(3).創(chuàng)建量化訓(xùn)練圖:tf.contrib.quantize.create_training_graph(quant_delay=0)
(4).定義optimizer優(yōu)化器
(5).saver = tf.train.Saver(tf.global_variables())
(6).變量初始化
(7).check_point載入
(8).保存pbtxt文件:tf.train.write_graph(sess.graph_def, FLAGS.train_dir,FLAGS.model_architecture + '.pbtxt')
(9).循環(huán)訓(xùn)練
freeze.py文件的代碼順序如下:
(1).創(chuàng)建graph
(2).創(chuàng)建量化eval圖:create_eval_graph()
(3).載入模型的各個變量的參數(shù)
(4).保存pb文件:
input_saver_def = saver.as_saver_def()
?
frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(input_graph_def=tf.get_default_graph().as_graph_def(),input_saver_def=input_saver_def,input_checkpoint = FLAGS.model_file,output_node_names='result',restore_op_name='save/restore_all', filename_tensor_name='save/Const:0',clear_devices=True,output_graph='',initializer_nodes='')
?
binary_graph = 'tflite_graph.pb'
?
with tf.gfile.GFile(binary_graph, 'wb') as f:
?
???????f.write(frozen_graph_def.SerializeToString())
如果使用以下方法生成pb文件,在android上運行時會報錯如下:
#將圖中變量轉(zhuǎn)變成常量:
frozen_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['labels_softmax'])
?
#保存pb文件:
tf.train.write_graph(frozen_graph_def,os.path.dirname(FLAGS.output_file),os.path.basename(FLAGS.output_file),as_text=False)
那么會報以下錯:
Caused by: java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model at org.tensorflow.lite.NativeInterpreterWrapper.createModelWithBuffer(Native Method)
這是因為生成的pb文件不是lite專用的flatbuffer格式.
?
還有,可以通過Netron查看pb文件或者lite文件里面的graph結(jié)構(gòu),可以看里面是否存在FakeQuantWithMinMaxVars操作來確實是否存在量化操作,Netron網(wǎng)頁版如下: https://lutzroeder.github.io/netron/
?
生成后的pb文件通過如下代碼可生成lite文件,分別為兩種方法:
(1).通過以下代碼可以生成完全量化的lite文件,通過Netron可以看到模型里面的變量,計算過程都是uint8.在移動端運行速度會快一些.
import tensorflow as tf
?
import pathlib2 as pathlib
?
?
?
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph('model.pb',["input_image"],["result"], input_shapes={"input_image":[1,626,361,3]}) ? #Python 2.7.6版本,但測試量化后模型大小不會變小
converter = tf.lite.TFLiteConverter.from_frozen_graph('model.pb',["input_image"],["result"], input_shapes={"input_image":[1,626,361,3]}) ? #python3.4.3--nightly版本,測試量化后模型大小會變小
?
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
?
converter.quantized_input_stats = {"input_image" : (127, 2.)}
?
converter.default_ranges_stats=(0, 6)
?
tflite_quantized_model=converter.convert()
?
open("quantized_model.tflite", "wb").write(tflite_quantized_model)
注意:
<1>.其中的quantized_input_stats傳入的參數(shù)為mean均值跟std方差,這兩個值可以通過訓(xùn)練數(shù)據(jù)進行統(tǒng)計獲得.
<2>.其中的default_ranges_stats的作用是對于偽量化后模型中不存在min跟max的激化函數(shù)設(shè)置默認的min跟max.具體的值設(shè)置使用自己估算的activation范圍activation.
?
(2).通過以下代碼可以生成偽量化的lite文件,通過Netron可以看到模型里面的變量,計算過程都還是float.在移動端運行速度會比較慢.
import tensorflow as tf
?
import pathlib2 as pathlib
?
?
?
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph('model.pb',["input_image"],["result"], input_shapes={"input_image":[1,626,361,3]}) ? #Python 2.7.6版本,但測試量化后模型大小不會變小
converter = tf.lite.TFLiteConverter.from_frozen_graph('model.pb',["input_image"],["result"], input_shapes={"input_image":[1,626,361,3]}) ? #python3.4.3--nightly版本,測試量化后模型大小會變小
?
converter.post_training_quantize = True
?
tflite_quantized_model=converter.convert()
?
open("quantized_model.tflite", "wb").write(tflite_quantized_model)
?
以上方法經(jīng)過測試,生成的lite文件是可以正常在android上運行的.
?
?
可參考
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
https://blog.csdn.net/computerme/article/details/80699671
https://tensorflow.juejin.im/performance/quantization.html
?
---------------------?
作者:程序猿也可以很哲學(xué)?
來源:CSDN?
原文:https://blog.csdn.net/qq_16564093/article/details/78996563?
版權(quán)聲明:本文為博主原創(chuàng)文章,轉(zhuǎn)載請附上博文鏈接!
總結(jié)
以上是生活随笔為你收集整理的Tensorflow Lite之编译生成tflite文件的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 从源码透析gRPC调用原理
- 下一篇: 《深入理解java虚拟机》第1章 走近J