巨省显存的重计算技巧在TF、Keras中的正确打开方式
一只小狐貍帶你解鎖 煉丹術(shù)&NLP?秘籍
作者:蘇劍林(來自追一科技,人稱“蘇神”)
前言
在前不久的文章《BERT重計(jì)算:用22.5%的訓(xùn)練時間節(jié)省5倍的顯存開銷(附代碼)》中介紹了一個叫做“重計(jì)算”的技巧(附pytorch和paddlepaddle實(shí)現(xiàn))。簡單來說重計(jì)算就是用來省顯存的方法,讓平均訓(xùn)練速度慢一點(diǎn),但batch_size可以增大好幾倍,該技巧首先發(fā)布于論文《Training Deep Nets with Sublinear Memory Cost》。
最近筆者發(fā)現(xiàn),重計(jì)算的技巧在tensorflow也有實(shí)現(xiàn)。事實(shí)上從tensorflow1.8開始,tensorflow就已經(jīng)自帶了該功能了,當(dāng)時被列入了tf.contrib這個子庫中,而從tensorflow1.15開始,它就被內(nèi)置為tensorflow的主函數(shù)之一,那就是tf.recompute_grad。找到?tf.recompute_grad?之后,筆者就琢磨了一下它的用法,經(jīng)過一番折騰,最終居然真的成功地用起來了,居然成功地讓?batch_size?從48增加到了144!然而,在繼續(xù)整理測試的過程中,發(fā)現(xiàn)這玩意居然在tensorflow 2.x是失效的...于是再折騰了兩天,查找了各種資料并反復(fù)調(diào)試,最終算是成功地補(bǔ)充了這一缺陷。
最后是筆者自己的開源實(shí)現(xiàn):
Github地址:
https://github.com/bojone/keras_recompute
該實(shí)現(xiàn)已經(jīng)內(nèi)置在bert4keras中,使用bert4keras的讀者可以升級到最新版本(0.7.5+)來測試該功能。
使用
筆者的實(shí)現(xiàn)也命名為recompute_grad,它是一個裝飾器,用于自定義Keras層的?call函數(shù),比如
from recompute import recompute_gradclass MyLayer(Layer): @recompute_grad def call(self, inputs): return inputs * 2對于已經(jīng)存在的層,可以通過繼承的方式來裝飾:
自定義好層之后,在代碼中嵌入自定義層,然后在執(zhí)行代碼之前,加入環(huán)境變量RECOMPUTE=1來啟用重計(jì)算。
注意:不是在總模型里插入了@recomputr_grad,就能達(dá)到省內(nèi)存的目的,而是要在每個層都插入@recomputr_grad才能更好地省顯存。簡單來說,就是插入的@recomputr_grad越多,就省顯存。具體原因請仔細(xì)理解重計(jì)算的原理。
效果
bert4keras0.7.5已經(jīng)內(nèi)置了重計(jì)算,直接傳入環(huán)境變量RECOMPUTE=1就會啟用重計(jì)算,讀者可以自行嘗試,大概的效果是:
1、在BERT Base版本下,batch_size可以增大為原來的3倍左右;
2、在BERT Large版本下,batch_size可以增大為原來的4倍左右;
3、平均每個樣本的訓(xùn)練時間大約增加25%;
4、理論上,層數(shù)越多,batch_size可以增大的倍數(shù)越大。
環(huán)境
在下面的環(huán)境下測試通過:
tensorflow 1.14 + keras 2.3.1
tensorflow 1.15 + keras 2.3.1
tensorflow 2.0 + keras 2.3.1
tensorflow 2.1 + keras 2.3.1
tensorflow 2.0 + 自帶tf.keras
tensorflow 2.1 + 自帶tf.keras
確認(rèn)不支持的環(huán)境:
tensorflow 1.x + 自帶tf.keras歡迎報(bào)告更多的測試結(jié)果。
順便說一下,強(qiáng)烈建議用keras2.3.1配合tensorflow1.x/2.x來跑,強(qiáng)烈不建議使用tensorflow 2.x自帶的tf.keras來跑。
可
能
喜
歡
算法工程師的效率神器——vim篇
硬核推導(dǎo)Google AdaFactor:一個省顯存的寶藏優(yōu)化器
數(shù)據(jù)缺失、混亂、重復(fù)怎么辦?最全數(shù)據(jù)清洗指南讓你所向披靡
LayerNorm是Transformer的最優(yōu)解嗎?
ACL2020|FastBERT:放飛BERT的推理速度
夕小瑤的賣萌屋
_
關(guān)注&星標(biāo)小夕,帶你解鎖AI秘籍
訂閱號主頁下方「撩一下」有驚喜哦
總結(jié)
以上是生活随笔為你收集整理的巨省显存的重计算技巧在TF、Keras中的正确打开方式的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何做机器学习项目规划?一个事半功倍的c
- 下一篇: 对比学习有多火?文本聚类都被刷爆了…