tensorflow打印模型图_从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)...
最近看到一個巨牛的人工智能教程,分享一下給大家。教程不僅是零基礎(chǔ),通俗易懂,而且非常風(fēng)趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。平時碎片時間可以當(dāng)小說看,【點(diǎn)這里可以去膜拜一下大神的“小說”】。
Tensorflow官方提供的Tensorboard可以可視化神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)圖,但是說實(shí)話,我?guī)缀鯊膩聿挥谩V饕且驗(yàn)門ensorboard中查看到的圖結(jié)構(gòu)太混亂了,包含了網(wǎng)絡(luò)中所有的計(jì)算節(jié)點(diǎn)(讀取數(shù)據(jù)節(jié)點(diǎn)、網(wǎng)絡(luò)節(jié)點(diǎn)、loss計(jì)算節(jié)點(diǎn)等等)。更可怕的是,如果一個計(jì)算節(jié)點(diǎn)是由多個基礎(chǔ)計(jì)算(如加減乘除等)構(gòu)成,那么在Tensorboard中會將基礎(chǔ)計(jì)算節(jié)點(diǎn)顯示而不是作為一個整體顯示(典型的如Squeeze計(jì)算節(jié)點(diǎn))。最近為了排查網(wǎng)絡(luò)結(jié)構(gòu)BUG花費(fèi)一周時間,因此,狠下心來決定自己寫一個工具,將Tensorflow中的圖以最簡單的方式顯示最關(guān)鍵的網(wǎng)絡(luò)結(jié)構(gòu)。
1 Tensor對象與Operation對象
Tensorflow中,Tensor對象主要用于存儲數(shù)據(jù)如常量和變量(訓(xùn)練參數(shù)),Operation對象是計(jì)算節(jié)點(diǎn),如卷積計(jì)算、反卷積計(jì)算、ReLU等等。每一個Operation對象均有輸入和輸出Tensor,同理,每個Tensor對象均有對應(yīng)生成該Tensor的Operation對象和使用該Tensor對象作為輸入的Operation對象。Tensor和Operation對象內(nèi)均有相關(guān)屬性和函數(shù)來獲取其關(guān)聯(lián)的Operation和Tensor對象,相關(guān)屬性如下所示。
Tensor對象的op屬性指向生成該Tensor的Operation對象。
Tensor對象的consumers()函數(shù)獲取使用該Tensor對象作為輸入的Operation對象。
Operation對象的inputs屬性指向該計(jì)算節(jié)點(diǎn)的輸入Tensor對象。
Operation對象的outputs屬性執(zhí)行該計(jì)算節(jié)點(diǎn)的輸出Tensor對象。
如下圖所示的網(wǎng)絡(luò)結(jié)構(gòu)中,調(diào)用Tensor_2對象的consumers()函數(shù),返回的是[op_1,op_2]。Tensor_3的op屬性指向的是op_1。op_1的inputs屬性指向的是[Tensor_1,Tensor_2],op_1的output屬性指向的是[Tensor_3]。
Tensor與Operation
有了Tensor與Operation對應(yīng)在圖中的關(guān)聯(lián)關(guān)系,就可以將網(wǎng)絡(luò)結(jié)構(gòu)給畫出來。
2 提取pb文件中的網(wǎng)絡(luò)結(jié)構(gòu)圖
pb文件是將模型參數(shù)固化到圖文件中,并合并了一些基礎(chǔ)計(jì)算和刪除了反向傳播相關(guān)計(jì)算得到的protobuf協(xié)議文件。如果讀者還不懂如何將CKPT模型文件轉(zhuǎn)pb文件,請參考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1節(jié)部分。有了pb模型文件后,接下來是加載模型,加載pb模型示例代碼如下所示。
def read_graph_from_pb(tf_model_path ,input_names,output_name):
with open(tf_model_path, 'rb') as f:
serialized = f.read()
tf.reset_default_graph()
gdef = tf.GraphDef()
gdef.ParseFromString(serialized)
with tf.Graph().as_default() as g:
tf.import_graph_def(gdef, name='')
with tf.Session(graph=g) as sess:
OPS=get_ops_from_pb(g,input_names,output_name)
return OPS
其中,倒數(shù)第2行調(diào)用到的函數(shù)get_ops_from_pb()用于獲取網(wǎng)絡(luò)結(jié)構(gòu)圖中指定輸入節(jié)點(diǎn)和指定輸出節(jié)點(diǎn)之間的計(jì)算節(jié)點(diǎn)。之所以要指定輸入和輸出,是為了將輸入之前的計(jì)算節(jié)點(diǎn)(如加載數(shù)據(jù)隊(duì)列等相關(guān)計(jì)算節(jié)點(diǎn))和輸出之后的計(jì)算節(jié)點(diǎn)(如計(jì)算loss等相關(guān)計(jì)算節(jié)點(diǎn))去除,免得礙眼。函數(shù)get_ops_from_pb()實(shí)現(xiàn)代碼如下。
def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
if save_ori_network:
with open('ori_network.txt','w+') as w:
OPS=graph.get_operations()
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\n')
inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
output_tf =graph.get_tensor_by_name(output_name)
OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] )
with open('network.txt','w+') as w:
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\n')
OPS = sort_ops(OPS)
OPS = merge_layers(OPS)
return OPS
在裁剪網(wǎng)絡(luò)結(jié)構(gòu)(即只保留input_names和output_name之間節(jié)點(diǎn))之前,先將原始的網(wǎng)絡(luò)結(jié)構(gòu)寫入到ori_network.txt中,文件中,每一行寫入:輸入Tensor---->op---->輸出Tensor。接下來調(diào)用函數(shù)get_ops_from_inputs_outputs獲取指定節(jié)點(diǎn)之間的節(jié)點(diǎn)。并調(diào)用sort_ops函數(shù)對所有的節(jié)點(diǎn)排序,以保證被依賴的節(jié)點(diǎn)總是出現(xiàn)在相關(guān)節(jié)點(diǎn)之前。最后調(diào)用merge_layers函數(shù),將一些可以合并的計(jì)算合并成一個獨(dú)立的節(jié)點(diǎn),例如,Squeeze計(jì)算相關(guān)節(jié)點(diǎn)合并成一個單獨(dú)的Squeeze節(jié)點(diǎn),又如const-->identity兩個計(jì)算節(jié)點(diǎn)可以直接忽略(即刪除)。
注意:篇幅有限,這里不再將函數(shù)get_ops_from_inputs_outputs、sort_ops、merge_layers貼出,相關(guān)代碼請前往文尾提供的源碼地址中閱讀。
3 繪制網(wǎng)絡(luò)結(jié)構(gòu)
考慮到SVG繪制圖形的簡單易用優(yōu)點(diǎn),將排好序的網(wǎng)絡(luò)計(jì)算節(jié)點(diǎn)和相關(guān)Tensor對象數(shù)據(jù)以Javascript字符串的形式寫入到HTML中,使用標(biāo)簽繪制箭頭,使用標(biāo)簽繪制矩形,使用標(biāo)簽繪制橢圓,使用標(biāo)簽顯示文字。繪制類似于如下所示圖像
繪制網(wǎng)絡(luò)結(jié)構(gòu)示例
注意:篇幅有限,這里不再介紹Javascript代碼解析模型結(jié)構(gòu)和SVG顯示相關(guān)的原理,相關(guān)代碼請前往文尾提供的源碼地址中閱讀。
4 測試模型顯示
以《MobileNet V1官方預(yù)訓(xùn)練模型的使用》文中介紹的MobileNet V1網(wǎng)絡(luò)結(jié)構(gòu)為例,下載MobileNet_v1_1.0_192文件并壓縮后,得到mobilenet_v1_1.0_192_frozen.pb文件。我們還需要知道m(xù)obilenet_v1_1.0_192_frozen.pb模型對應(yīng)的輸入和輸出Tensor對象的名稱,好在MobileNet_v1_1.0_192壓縮包中包含文件mobilenet_v1_1.0_192_info.txt。通過該文件可知,輸入Tensor的名稱為:input:0,輸出Tensor名稱為:MobilenetV1/Predictions/Reshape_1:0。有了這些信息后,調(diào)用函數(shù)read_graph_from_pb得到靜態(tài)圖的節(jié)點(diǎn)列表對象ops,調(diào)用函數(shù)gen_graph(ops,"save/path/graph.html")后,在目錄save/path中得到graph.html文件,打開graph.html后,顯示結(jié)果如下。
顯示網(wǎng)絡(luò)結(jié)構(gòu)分兩種模式:合并模式和展開模式,分別如下圖所示。
合并模式網(wǎng)絡(luò)結(jié)構(gòu)
截取的展開模式網(wǎng)絡(luò)結(jié)構(gòu)
5 源碼地址
總結(jié)
以上是生活随笔為你收集整理的tensorflow打印模型图_从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)...的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 美股周五:三大股指连续4周收涨,英伟达跌
- 下一篇: 存储行业新突破:Cerabyte 展示陶