Tensorflow源码解析3 -- TensorFlow核心对象 - Graph
1 Graph概述
計算圖Graph是TensorFlow的核心對象,TensorFlow的運行流程基本都是圍繞它進行的。包括圖的構建、傳遞、剪枝、按worker分裂、按設備二次分裂、執行、注銷等。因此理解計算圖Graph對掌握TensorFlow運行尤為關鍵。
2 默認Graph
默認圖替換
之前講解Session的時候就說過,一個Session只能run一個Graph,但一個Graph可以運行在多個Session中。常見情況是,session會運行全局唯一的隱式的默認的Graph,operation也是注冊到這個Graph中。
也可以顯示創建Graph,并調用as_default()使他替換默認Graph。在該上下文管理器中創建的op都會注冊到這個graph中。退出上下文管理器后,則恢復原來的默認graph。一般情況下,我們不用顯式創建Graph,使用系統創建的那個默認Graph即可。
print tf.get_default_graph()with tf.Graph().as_default() as g:print tf.get_default_graph() is gprint tf.get_default_graph()print tf.get_default_graph()輸出如下
<tensorflow.python.framework.ops.Graph object at 0x106329fd0> True <tensorflow.python.framework.ops.Graph object at 0x18205cc0d0> <tensorflow.python.framework.ops.Graph object at 0x10d025fd0>由此可見,在上下文管理器中,當前線程的默認圖被替換了,而退出上下文管理后,則恢復為了原來的默認圖。
默認圖管理
默認graph和默認session一樣,也是線程作用域的。當前線程中,永遠都有且僅有一個graph為默認圖。TensorFlow同樣通過棧來管理線程的默認graph。
@tf_export("Graph") class Graph(object):# 替換線程默認圖def as_default(self):return _default_graph_stack.get_controller(self)# 棧式管理,push pop@tf_contextlib.contextmanagerdef get_controller(self, default):try:context.context_stack.push(default.building_function, default.as_default)finally:context.context_stack.pop()替換默認圖采用了堆棧的管理方式,通過push pop操作進行管理。獲取默認圖的操作如下,通過默認graph棧_default_graph_stack來獲取。
@tf_export("get_default_graph") def get_default_graph():return _default_graph_stack.get_default()下面來看_default_graph_stack的創建
_default_graph_stack = _DefaultGraphStack() class _DefaultGraphStack(_DefaultStack): def __init__(self):# 調用父類來創建super(_DefaultGraphStack, self).__init__()self._global_default_graph = Noneclass _DefaultStack(threading.local):def __init__(self):super(_DefaultStack, self).__init__()self._enforce_nesting = True# 和默認session棧一樣,本質上也是一個listself.stack = []_default_graph_stack的創建如上所示,最終和默認session棧一樣,本質上也是一個list。
3 前端Graph數據結構
Graph數據結構
理解一個對象,先從它的數據結構開始。我們先來看Python前端中,Graph的數據結構。Graph主要的成員變量是Operation和Tensor。Operation是Graph的節點,它代表了運算算子。Tensor是Graph的邊,它代表了運算數據。
@tf_export("Graph") class Graph(object):def __init__(self):# 加線程鎖,使得注冊op時,不會有其他線程注冊op到graph中,從而保證共享graph是線程安全的self._lock = threading.Lock()# op相關數據。# 為graph的每個op分配一個id,通過id可以快速索引到相關op。故創建了_nodes_by_id字典self._nodes_by_id = dict() # GUARDED_BY(self._lock)self._next_id_counter = 0 # GUARDED_BY(self._lock)# 同時也可以通過name來快速索引op,故創建了_nodes_by_name字典self._nodes_by_name = dict() # GUARDED_BY(self._lock)self._version = 0 # GUARDED_BY(self._lock)# tensor相關數據。# 處理tensor的placeholderself._handle_feeders = {}# 處理tensor的read操作self._handle_readers = {}# 處理tensor的move操作self._handle_movers = {}# 處理tensor的delete操作self._handle_deleters = {}下面看graph如何添加op的,以及保證線程安全的。
def _add_op(self, op):# graph被設置為final后,就是只讀的了,不能添加op了。self._check_not_finalized()# 保證共享graph的線程安全with self._lock:# 將op以id和name分別構建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便后續快速索引self._nodes_by_id[op._id] = opself._nodes_by_name[op.name] = opself._version = max(self._version, op._id)GraphKeys 圖分組
每個Operation節點都有一個特定的標簽,從而實現節點的分類。相同標簽的節點歸為一類,放到同一個Collection中。標簽是一個唯一的GraphKey,GraphKey被定義在類GraphKeys中,如下
@tf_export("GraphKeys") class GraphKeys(object):GLOBAL_VARIABLES = "variables"QUEUE_RUNNERS = "queue_runners"SAVERS = "savers"WEIGHTS = "weights"BIASES = "biases"ACTIVATIONS = "activations"UPDATE_OPS = "update_ops"LOSSES = "losses"TRAIN_OP = "train_op"# 省略其他name_scope 節點命名空間
使用name_scope對graph中的節點進行層次化管理,上下層之間通過斜杠分隔。
# graph節點命名空間 g = tf.get_default_graph() with g.name_scope("scope1"):c = tf.constant("hello, world", name="c")print c.op.namewith g.name_scope("scope2"):c = tf.constant("hello, world", name="c")print c.op.name輸出如下
scope1/c scope1/scope2/c # 內層的scope會繼承外層的,類似于棧,形成層次化管理4 后端Graph數據結構
Graph
先來看graph.h文件中的Graph類的定義,只看關鍵代碼
class Graph {private:// 所有已知的op計算函數的注冊表FunctionLibraryDefinition ops_;// GraphDef版本號const std::unique_ptr<VersionDef> versions_;// 節點node列表,通過id來訪問std::vector<Node*> nodes_;// node個數int64 num_nodes_ = 0;// 邊edge列表,通過id來訪問std::vector<Edge*> edges_;// graph中非空edge的數目int num_edges_ = 0;// 已分配了內存,但還沒使用的node和edgestd::vector<Node*> free_nodes_;std::vector<Edge*> free_edges_;}后端中的Graph主要成員也是節點node和邊edge。節點node為計算算子Operation,邊為算子所需要的數據,或者代表節點間的依賴關系。這一點和Python中的定義相似。邊Edge的持有它的源節點和目標節點的指針,從而將兩個節點連接起來。下面看Edge類的定義。
Edge
class Edge {private:Edge() {}friend class EdgeSetTest;friend class Graph;// 源節點, 邊的數據就來源于源節點的計算。源節點是邊的生產者Node* src_;// 目標節點,邊的數據提供給目標節點進行計算。目標節點是邊的消費者Node* dst_;// 邊id,也就是邊的標識符int id_;// 表示當前邊為源節點的第src_output_條邊。源節點可能會有多條輸出邊int src_output_;// 表示當前邊為目標節點的第dst_input_條邊。目標節點可能會有多條輸入邊。int dst_input_; };Edge既可以承載tensor數據,提供給節點Operation進行運算,也可以用來表示節點之間有依賴關系。對于表示節點依賴的邊,其src_output_, dst_input_均為-1,此時邊不承載任何數據。
下面來看Node類的定義。
Node
class Node {public:// NodeDef,節點算子Operation的信息,比如op分配到哪個設備上了,op的名字等,運行時有可能變化。const NodeDef& def() const;// OpDef, 節點算子Operation的元數據,不會變的。比如Operation的入參列表,出參列表等const OpDef& op_def() const;private:// 輸入邊,傳遞數據給節點。可能有多條EdgeSet in_edges_;// 輸出邊,節點計算后得到的數據??赡苡卸鄺lEdgeSet out_edges_; }節點Node中包含的主要數據有輸入邊和輸出邊的集合,從而能夠由Node找到跟他關聯的所有邊。Node中還包含NodeDef和OpDef兩個成員。NodeDef表示節點算子的信息,運行時可能會變,創建Node時會new一個NodeDef對象。OpDef表示節點算子的元信息,運行時不會變,創建Node時不需要new OpDef,只需要從OpDef倉庫中取出即可。因為元信息是確定的,比如Operation的入參個數等。
由Node和Edge,即可以組成圖Graph,通過任何節點和任何邊,都可以遍歷完整圖。Graph執行計算時,按照拓撲結構,依次執行每個Node的op計算,最終即可得到輸出結果。入度為0的節點,也就是依賴數據已經準備好的節點,可以并發執行,從而提高運行效率。
系統中存在默認的Graph,初始化Graph時,會添加一個Source節點和Sink節點。Source表示Graph的起始節點,Sink為終止節點。Source的id為0,Sink的id為1,其他節點id均大于1.
5 Graph運行時生命周期
Graph是TensorFlow的核心對象,TensorFlow的運行均是圍繞Graph進行的。運行時Graph大致經過了以下階段
這些階段根據TensorFlow運行時的不同,會進行不同的處理。運行時有兩種,本地運行時和分布式運行時。故Graph生命周期到后面分析本地運行時和分布式運行時的時候,再詳細講解。
本文作者:揚易
閱讀原文
本文為云棲社區原創內容,未經允許不得轉載。
總結
以上是生活随笔為你收集整理的Tensorflow源码解析3 -- TensorFlow核心对象 - Graph的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 带上一份技能地图
- 下一篇: 归纳DOM事件中各种阻止方法