BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)
一只小狐貍帶你解鎖 煉丹術&NLP?秘籍
作者:夕小瑤、rumor醬
前言
雖然TPU的顯存令人羨慕,但是由于眾所周知的原因,絕大部分人還是很難日常化使用的。英偉達又一直在擠牙膏,至今單卡的最大顯存也僅僅到32G(參考V100、DGX-2)。然而,訓練一個24層的BERT Large模型的時候,如果sequence length開滿512,那么batch size僅僅開到8(有時候能到10)就把這寥寥32G的顯存打滿了。如果想訓練一個48層乃至100層的BERT Large,那完全是土豪們的游戲了,需要瘋狂的模型并行+分布式多機訓練。
但!是!萬能的小夕前不久在Daxiang Dong大佬的安利下,發現了@陳天奇 大佬2016年的一篇寶藏paper!
簡單的劃一下重點:
這篇paper用時間換空間的思想,在前向時只保存部分中間節點,在反向時重新計算沒保存的部分。論文通過這種機制,在每個batch只多計算一次前向的情況下,把n層網絡的占用顯存優化到了。在極端情況下,仍可用的計算時間換取到的顯存占用。在論文的實驗中,他們成功將將1000層的殘差網絡從48G優化到了7G。且,這種方法同樣可以直接應用于RNN結構中。
看完摘要,瞬間感覺在小破卡上訓練BERT Large有救了!!!
此外,來快速過一遍paper中最重要的三點結論:
梯度計算等價,理論上沒有精度損失
可以節省4倍+的顯存開銷
訓練速度僅僅會被拖慢30%
image-20200420140806122不過論文發表在2016年,當時還沒有BERT,不過Baidu Paddle團隊補了一個BERT的實驗結果,發現在BERT上面只用22.5%的訓練速度損失就能換來5倍+的顯存開銷節省!相關實驗在本文末尾,不著急,接下來我們先一起分析一下在訓練階段時顯存為什么容易不足。
感謝Baidu Paddle團隊提供本節圖文素材和測試數據
訓練階段顯存為何不足
深度學習中,網絡的一次訓練包含前向計算、后向計算和優化三個步驟。
在這個過程中,前向計算會輸出大量的隱層變量Tensor,當模型層數加深時,Tensor數量可達成千上萬個。如Bert Large模型,單個Tensor可達到1GB,這些Tensor在顯存中累積,顯存很快就爆掉了╮( ̄▽ ̄"")╭
下圖是Bert Large模型在一次訓練過程中的顯存使用情況,可以明顯看到在前向計算過程中,顯存累積趨勢是一個陡峭的上升直線。而在反向計算過程中,這些隱層Tensor又會很快地被消耗掉,又是一個陡峭的下降曲線,顯存直接降到低位。
那么問題來了,為什么不直接刪除這些前向計算的Tensor呢?
答案很簡單,因為這些隱層的Tensor在反向的時會被用到(手動狗頭
來個簡單的證明。
假設前向計算中有一個矩陣乘法計算:
Y = W × X
對W求梯度:
很容易發現,對W求梯度的公式里有X,而X就是那個巨能吃顯存的隱層Tensor╮( ̄▽ ̄"")╭
那我們是否可以暫時扔掉這些隱層Tensor,在反向計算時再把它們重新生成出來呢?當然可以,這正是上面這篇paper的思想。
重計算
顧名思義,"重計算"就是讓每個訓練迭代過程做兩次前向計算,看起來有點奇怪,實際上卻非常有效!對于剛剛那個吃顯存的Bert Large,支持重計算機制后,顯存占用直接從175GB降低到20GB,陡峭的顯存上升直線變成了緩慢增長的Z形曲線,如下圖所示。
核心思想是將前向計算分割成多個段,將每個段的起始Tensor作為這個段的檢查點(checkpoints)。前向計算時,除了檢查點以外的其他隱層Tensor占有的顯存可以及時釋放。反向計算用到這些隱層Tensor時,從前一個檢查點開始,重新進行這個段的前向計算,就可以重新獲得隱層Tensor。
重計算機制有點像玩單機游戲。每過一個關卡就會保存一個檢查點,而隱層Tensor就相當于游戲中任何一個時刻的圖像。普通的訓練方式是打通一遍游戲,并且將游戲中所有時刻的圖像保存下來;而重計算機制的思路是先把游戲通關,保存檢查點,后面當收到某一時刻圖像的請求時,再重打一遍這一關卡就可以了。
如下圖,舉一個簡單的例子,添加重計算機制前,前向計算中需要存儲的隱層是4個紅點;添加重計算機制后,需要存儲的隱層變為2個藍點, 從而節省了這部分內存。
雖然時間也是寶貴的,但重計算方法的性價比很高。在論文的實驗中,作者用30%的計算時間換取了4倍的內存空間。并且重計算只是重復了一次前向的過程,理論上精度沒有任何損失。
那么這么寶藏的算法有沒有開源實現呢?
開源實現
調研了一波,似乎TF沒有原生支持,但是生態里有第三方實現;pytorch和paddlepaddle中都有原生API支持
Pytorch:
torch.utils.checkpoint
PaddlePaddle:
optimizer.RecomputeOptimizer
不過pytorch的文檔比較略,也沒有提供更細致的示例和相關數據,有興趣的小伙伴自行試一下。paddle框架中提供了詳細到哭的文檔,甚至還有一個現成的BERT+重計算的例子,以及非常詳細的實驗測試結果。這里直接貼過來(真香系列
Paddle中實現顯存重計算大體分為三步:
定義一個經典的優化器,如SGD優化器;
在外面包一層重計算優化器;
設置檢查點。
以MLP為例,只需要增加兩行代碼就可以進入重計算模式
import?paddle.fluid?as?fluid # 定義MLP def mlp(input_x, input_y, hid_dim=128, label_dim=2):print(input_x)fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')cost = fluid.layers.cross_entropy(input=prediction, label=input_y)sum_cost = fluid.layers.reduce_mean(cost)return sum_cost, fc_1, predictioninput_x = fluid.layers.data(name="x", shape=[32], dtype='float32') input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') cost, fc_1, pred = mlp(input_x, input_y)# 定義RecomputeOptimizersgd = fluid.optimizer.SGD(learning_rate=0.01) recompute_optimizer = fluid.optimizer.RecomputeOptimizer(sgd) # 設置checkpoints recompute_optimizer._set_checkpoints([fc_1, pred]) # 運行優化算法 recompute_optimizer.minimize(cost)該示例github鏈接:https://github.com/PaddlePaddle/examples/blob/master/community_examples/recompute/demo.py
此外,官方還給出了一個BERT中做重計算的示例
github鏈接:https://github.com/PaddlePaddle/Fleet/tree/develop/examples/recompute/bert
BERT實驗結論(劃重點
根據上面paddle官方提供的BERT示例和實驗結果,得出以下幾個結論
結論一
在32GB顯存的Tesla V100顯卡上應用重計算機制,可以訓練更大、更深的深度學習模型。當num_tokens為4096(batch size=32,seqlen=128)時,可以訓練100層的Bert網絡。
從Github的實驗結果也可以看出,顯存上的收益比速度的損失要大很多:
在batch_size上提升了5倍,速度只降低了約1/5,且精度沒有損失。
結論二
模型訓練的batch size最大可提升為原來的5倍+,且只有少量的速度損失。
重計算機制在Bert Large這一模型上收益最大,最大batch size從93提升到562!而在VGG-16這種比較淺的模型上,重計算機制的收益則比較小。這充分符合重計算機制的設計理念:為了訓練更大、更深的模型。
結論三
在古董顯卡Tesla K40顯卡(12G顯存)上,訓練BERT Large時batch size可以開到130
最后,希望本文可以幫助大家在小破卡上盡情訓練BERT Large~
夕小瑤的賣萌屋
關注&星標小夕,帶你解鎖AI秘籍
訂閱號主頁下方「撩一下」有驚喜哦
總結
以上是生活随笔為你收集整理的BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何匹配两段文本的语义?
- 下一篇: 做一个好的搜索引擎有多难