tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...
地址:http://bindog.github.io/
01
背景
前幾天看到知乎上的文章FLOPs與模型推理速度[1],文中提到一個比較耗時又占顯存的pointwise操作x * sigmoid(x),這實際上是swish activation[2];暫且不提它背后的爭議,本文主要想從這個結構入手來優(yōu)化它的顯存占用以及耗時,并討論更廣泛的訓練時顯存優(yōu)化技術。02
反向傳播是如何工作的?
要分析清楚swish activation為什么會比較占顯存,我們首先需要搞清楚反向傳播是如何工作的,或者更進一步說,現(xiàn)有的自動求導框架是如何求出梯度的。先明確一點,所謂自動求導框架實際上是“半自動”的:它并非直接求出一個復雜函數(shù)導數(shù)的解析形式,而是通過構建計算圖和預先寫好的基礎函數(shù)的求導規(guī)則,結合鏈式求導法則實現(xiàn)的自動求導。以swish acivation為例進行說明,其表達式為f(x) = x * sigmoid(x),通過簡單的數(shù)學推導得到其梯度的解析式為f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把這個結果放一邊,看看自動求導框架是如何一步步求出這個結果的,畫出計算圖如下:除了計算圖以外,我們還需要定義幾個基本函數(shù)的求導規(guī)則,在這個例子里涉及兩個函數(shù),一個是乘法,另一個是sigmoid函數(shù)(實際上sigmoid也是由幾個基本函數(shù)構成的,這里我們將其視為一個整體)f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x) # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))03
顯存被誰吃掉了
先說一個結論,在絕大多數(shù)神經(jīng)網(wǎng)絡的訓練過程中,顯存占用的大頭是中間結果,也就是所謂的“特征圖”。那我們?yōu)槭裁匆A糁虚g結果呢?當然是為了方便求導啊!還是以swish acivation為例,把它放入神經(jīng)網(wǎng)絡來看,x就是前一層輸出的中間結果(特征圖)- 在適用乘法的求導規(guī)則時,要求我們要事先保留下中間結果x和sigmoid(x),有人可能會說只保留一個x不就可以了嗎?sigmoid(x)可以通過計算得出,注意框架定義的乘法及其求導規(guī)則是通用規(guī)則,乘法的左右兩邊完全可能是不相關的兩個值,所以必須同時保留下來。
- 在對sigmoid函數(shù)適用求導規(guī)則時,需要存下中間結果x。
04
手動合并OP
那么有沒有辦法優(yōu)化呢?當然是可以的,既然我們能用數(shù)學公式提前算出swish acivation的梯度,那么直接將其視為一個整體不就好了?無非就是定義一個新的函數(shù)和新的求導規(guī)則
swish(x) = x * sigmoid(x)# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))這樣一來,計算圖變成了下面這個樣子:
x的梯度可以直接根據(jù)新的規(guī)則求出,而在新的規(guī)則下,我們只需要保留x這一個中間結果即可,sigmoid(x)可以根據(jù)x求出。所以說,自動求導框架雖然省事,但是其缺陷也很明顯,由于大部分求導規(guī)則是面向通用的函數(shù),很難針對特定的場景進行自動優(yōu)化而導致顯存浪費。對swish acivation這樣的函數(shù),只能依靠工程師的經(jīng)驗手動的進行優(yōu)化。需要指出的是,現(xiàn)有的一些框架如TVM和TensorRT也能自動的對某些算子進行融合,進而大大提高計算效率,降低顯存消耗,但是這些都屬于部署階段了,而本文討論的均為訓練階段。類似的優(yōu)化案例還有inplace-abn[3],針對的是類似BN-ReLU-Conv這樣的常見結構組合,如下所示圖中的虛線框是需要保留的中間結果,inplace-abn的優(yōu)化思路是只保留中間結果z,通過反推得到x,然而眾所周知ReLU是不可逆的運算,因此inplace-abn將其替換為了Leaky ReLU,計算圖變成了如下形式:接下來的事情就是用數(shù)學的方式手動求出導數(shù),然后定義成規(guī)則即可。對II型,更進一步,直接用的反函數(shù),進行替換即可
雖然推導過程有些復雜,但寫出求導公式后,我們只需要將其封裝進手寫的模塊中即可。原論文[4]中的實現(xiàn)表明,采用Inplace-abn后,顯存占用最高可下降50%左右,而且由于Leaky ReLU實際效果其實與ReLU非常接近,省下來的顯存可以用于提高batch_size,模型訓練實際上能從中得到更大收益。
05
還能更進一步嗎?
回想前面的優(yōu)化過程,我們發(fā)現(xiàn)其實這是一種典型的時間換空間的做法,雖然模型占用的顯存下降了(舍棄了大量中間結果),但是我們定義的求導規(guī)則非常復雜,計算步驟明顯多于優(yōu)化前,其根本原因并非是不需要中間結果,而是有辦法在求導過程中實時的計算出之前被舍棄掉的中間結果。考慮GPU上顯存資源與計算資源的關系,只用較少的計算量和額外的一點計算時間換取寶貴的顯存資源,這么做實際上是劃算的。如果沿著這個思路更進一步,所有的中間結果都不需要存儲了,只需要存最初的輸入即可,因為所有的中間結果都可以由輸入重新計算得到,然而這個方案顯然是不劃算的,因為反向傳播的過程是“由深入淺”,而計算中間結果的過程是“由淺入深”,二者的方向并不匹配,每當我們需要中間結果時就需要從頭再來一遍,這樣的計算和時間開銷顯然是不劃算的。如果折中一下呢?這就是OpenAI提出的gradient-checkpoint的思路,在神經(jīng)網(wǎng)絡中間設置若干個檢查點(checkpoint),檢查點以外的中間結果全部舍棄,反向傳播求導數(shù)的時間,需要某個中間結果時,從最近的檢查點開始計算,這樣既節(jié)省了顯存,又避免了從頭計算的繁瑣過程;從代碼層面來看,原版實現(xiàn)[5]用的是tensorflow,由于是靜態(tài)圖的緣故,需要用到grapheditor等一系列騷操作,而且包含了很多“智能”尋找bottleneck選擇為checkpoint的代碼,很容易勸退新人。但是如果看一下pytorch的官方實現(xiàn)[6],你會驚訝的發(fā)現(xiàn)gradient-checkpoint的核心部分出奇的簡單,這也算是動態(tài)圖以及pytorch的一點小優(yōu)勢吧,當然pytorch版本的實現(xiàn)并不包括智能尋找checkpoint點的功能,需要人為設定。核心代碼如下所示:class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) ctx.had_cuda_in_fwd = False if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) ctx.save_for_backward(*args) with torch.no_grad(): outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") inputs = ctx.saved_tensors # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. rng_devices = [] if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: rng_devices = ctx.fwd_gpu_devices with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(inputs) with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None, None) + grads注意到最近曠視開源的MegEngine,在PR的時候提到一個亞線性顯存優(yōu)化技術[7],其實就是gradient-checkpoint技術,詳情可參考論文Training Deep Nets with Sublinear Memory Cost[8],當然MegEngine肯定在細節(jié)上對其進行了一些優(yōu)化,本文就不展開討論了。06
CUDA版的swish activation
回到swish activation的優(yōu)化上來,如果要追求效率的極致提升,下一步考慮的方案應該是手寫C++ extension,將計算從python層面轉移到C++與CUDA上。如何基于 pytorch寫C++擴展,官方文檔上有非常詳細的教程[9],寫法和方式也都比較靈活,可以根據(jù)自己的習慣進行選擇,這里我們選擇利用setuptools的方式進行構建pytorch在用戶自定義擴展上也是做了非常多的支持,用戶能非常方便的使用pytorch底層定義好的一些類和函數(shù);在寫CUDA函數(shù)時,pytorch還提供了一個CUDAApplyUtils.cuh頭文件,專門用于優(yōu)化pointwise操作的情況,以減小拷貝和臨時存儲的顯存浪費(用于lambda函數(shù),函數(shù)名非常直觀,CUDA_tensor_applyN表示操作數(shù)的個數(shù),N可以為1,2,3,4,用戶還可以指定每個操作數(shù)的屬性,如只讀/讀寫,針對每對情形都有專門的優(yōu)化實現(xiàn))對于swish activation來說,由于全是pointwise操作,利用這個優(yōu)化技巧可以把顯存占用進一步壓縮。具體代碼可參考swish_optimize[10]簡單對比一下以上幾種實現(xiàn)在實際場景中(單卡RTX 2070,resnet50, bs=32)的顯存占用情況和運行時間(一次forward & 一次backward & 參數(shù)更新)- 無優(yōu)化純Python:GPU memory=6383MB,time=223ms
- 合并算子(Python):GPU memory=5139MB,time=234ms
- 合并算子(CUDA):GPU memory=5143MB,time=188ms
[1] https://zhuanlan.zhihu.com/p/122943688
[2] https://arxiv.org/abs/1710.05941
[3] https://github.com/mapillary/inplace_abn
[4] https://arxiv.org/pdf/1712.02616.pdf
[5] https://github.com/cybertronai/gradient-checkpointing
[6] https://github.com/pytorch/pytorch/blob/176174a68ba2d36b9a5aaef0943421682ecc66d4/torch/utils/checkpoint.py#L55
[7] https://zhuanlan.zhihu.com/p/138730559
[8] https://arxiv.org/abs/1604.06174
[9] https://pytorch.org/tutorials/advanced/cpp_extension.html
[10] https://github.com/bindog/swish_optimize
本文目的在于學術交流,并不代表本公眾號贊同其觀點或對其內(nèi)容真實性負責,版權歸原作者所有,如有侵權請告知刪除。
直播預告
歷史文章推薦
【CVPR 2020 Tutorial】如何寫好論文和評審(概述)
如何撰寫高水平的博士論文?超全論文指導!
北大讀博手記:怎樣完成自己的博士生涯?非常具有指導性!
太牛逼了!一位中國博士把整個CNN都給可視化了,每個細節(jié)看的清清楚楚!
Nature發(fā)表牛津博士建議:我希望在讀博士之初時就能知道的20件事
沈向洋、華剛:讀科研論文的三個層次、四個階段與十個問題
如何看待2021年秋招算法崗灰飛煙滅?
獨家解讀 | ExprGAN:基于強度可控的表情編輯
獨家解讀 | 矩陣視角下的BP算法
獨家解讀 | Capsule Network深度解讀
獨家解讀 | Fisher信息度量下的對抗攻擊
論文解讀 | 知識圖譜最新研究綜述
你的畢業(yè)論文過了嗎?《如何撰寫畢業(yè)論文?》
卡爾曼濾波系列——經(jīng)典卡爾曼濾波推導
分享、點贊、在看,給個三連擊唄!
總結
以上是生活随笔為你收集整理的tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python收入波动告警分析_使用Pyt
- 下一篇: 标题隐藏_头条官方课程没看就想起好标题?