将TVM集成到PyTorch上
將TVM集成到PyTorch上
隨著TVM不斷展示出對深度學(xué)習(xí)執(zhí)行效率的改進(jìn),很明顯PyTorch將從直接利用編譯器堆棧中受益。PyTorch的主要宗旨是提供無縫且強(qiáng)大的集成,而這不會妨礙用戶。為此,PyTorch現(xiàn)在具有基于TVM的官方后端torch_tvm。
用法很簡單:
import torch_tvm
torch_tvm.enable()
PyTorch將嘗試在其JIT編譯過程中,將所有可能的運算符轉(zhuǎn)換為已知的Relay運算符。
背景
與許多其他ML框架不同,PyTorch公開了一個渴望執(zhí)行的編程接口。這種編程風(fēng)格避免了基于圖的元編程,而專注于以Python方式直接控制n維數(shù)組(張量)。因此,該框架最初非常適合模型的試驗和開發(fā),但不適用于自動性能優(yōu)化或部署。為了利用優(yōu)化的編譯器技術(shù),PyTorch引入了一些較大的更改來解決此問題。
PyTorch 1.0引入了PyTorch IR,PyTorch專用的中間表示形式,用于類似于Relay的模型。可以通過模型跟蹤將PyTorch程序轉(zhuǎn)換為IR,該跟蹤記錄模型或Python的子集TorchScript的執(zhí)行。新的TVM后端將PyTorch的IR降低到了Relay,并能夠透明地提高PyTorch的性能,而無需用戶參與。
整合與結(jié)果
為了支持Relay,PyTorch JIT添加了兩個功能:自定義轉(zhuǎn)換過程和自定義子圖解釋器。
當(dāng)torch_tvm啟用時,可以轉(zhuǎn)換到中繼PyTorch IR的子圖Expr旨意被標(biāo)記為繼電器兼容。由于PyTorch IR并不總是包含形狀信息,因此在調(diào)用之前,無法以有用的方式編譯任何子圖。
在用戶調(diào)用期間,PyTorch JIT運行時將確定輸入形狀信息,并使用新的Relay C ++構(gòu)建系統(tǒng)編譯先前標(biāo)記的子圖。根據(jù)輸入形狀來緩存編譯,以供后續(xù)運行。可以在README中找到更多詳細(xì)信息。
torch_tvm建立了一個連續(xù)的基準(zhǔn)測試系統(tǒng),該系統(tǒng)正在監(jiān)視ResNet18在CPU上的性能。對于各種ResNet型號,TVM的性能都是默認(rèn)PyTorch JIT后端的兩倍以上。在AWS c5n.4xlarge實例上使用16個線程實現(xiàn)的每秒迭代次數(shù)(越大越好)。
這些結(jié)果令人鼓舞,該項目將繼續(xù)致力于,在更多模型上提高CPU推理速度。
未來的工作
現(xiàn)在,PyTorch JIT進(jìn)行了大量工作來查找其IR的純功能子集,以饋送到Relay。這避免了將別名和控制流信息映射到中繼的需要,但這不是必需的。將更多的PyTorch IR映射到Relay可能會取得性能上的勝利,這是該項目的目標(biāo)。PyTorch IR在開發(fā)過程中正在迅速變化,因此必須謹(jǐn)慎進(jìn)行。
將做更多的工作來確保PyTorch和TVM代碼之間的切換是有效的。這包括統(tǒng)一線程模型,分配器以及減少與將輸入復(fù)制到TVM相關(guān)的開銷。
解析
如果已經(jīng)編寫了PyTorch模型,最簡單的入門方法就是使用torch.jit.trace以下方法
import torch_tvm
from your_model import model, inputs
torch_tvm.enable(opt_level=3)
iters = 100
warmup = 10
Ensure your model is in eval mode and also turn off gradients.
with torch.no_grad():
Use tuned parameters for better performance.
with autotvm.apply_history_best(“test/autotvm_tuning.log”):
# This is where all the compilation happens.
trace_tvm = torch.jit.trace(model, inputs)
# Warmup
for _ in range(warmup):_ = trace_tvm(*inputs)# Benchmark
start = time.time()
for _ in range(iters):_ = trace_tvm(*inputs)
tvm_time = time.time() - startprint("Took {}s to run {} iters".format(tvm_time, iters))
這段代碼大部分來自Benchmarks.py。請注意,用于AVX2 LLVM編譯的調(diào)整參數(shù)位于存儲庫test/文件夾中。
如果更直接使用Relay,可以通過(隱式)跟蹤或TorchScript,直接從PyTorch函數(shù)中提取表達(dá)式:
def add(a, b, c):
return a + b + c
via tracing
relay_graph = torch_tvm.to_relay(add, inputs)
@torch.jit.script
def mul(a, b, c):
return a * b * c
via script
relay_graph = torch_tvm.to_relay(mul, inputs)
總結(jié)
以上是生活随笔為你收集整理的将TVM集成到PyTorch上的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 自定义Kubernetes调度程序来编排
- 下一篇: 自动生成低精度深度学习算子