谈谈神经网络的大规模训练优化
文 | 立交橋跳水冠軍
源 | 知乎
大規(guī)模神經(jīng)網(wǎng)絡(luò)訓(xùn)練一般會(huì)涉及到幾百個(gè)分布式節(jié)點(diǎn)同時(shí)工作,模型的參數(shù)量以及運(yùn)算量往往很大,作者認(rèn)為在這個(gè)task下當(dāng)前的工作主要?dú)w結(jié)為以下三種:對(duì)通信本身的優(yōu)化,神經(jīng)網(wǎng)絡(luò)訓(xùn)練通信的優(yōu)化,大規(guī)模下如何保持精度。
之前一段時(shí)間接觸了大規(guī)模神經(jīng)網(wǎng)絡(luò)訓(xùn)練,看了不少優(yōu)秀的工作,在這里當(dāng)做筆記記下來(lái)。同時(shí)也希望可以拋磚引玉,和各位大佬交流一下這方面的現(xiàn)有工作以及未來(lái)的方向(1)大規(guī)模訓(xùn)練工作的幾種類型大規(guī)模訓(xùn)練和普通分布式訓(xùn)練還是有區(qū)別的,主要體現(xiàn)在大這個(gè)字上面。一般來(lái)說(shuō)會(huì)涉及到幾百個(gè)分布式節(jié)點(diǎn)同時(shí)工作,模型的參數(shù)量以及運(yùn)算量往往很大(比如BERT,GPT3等等)我認(rèn)為在這個(gè)task下當(dāng)前的工作主要?dú)w結(jié)為以下三種:
對(duì)通信本身的優(yōu)化
神經(jīng)網(wǎng)絡(luò)訓(xùn)練通信的優(yōu)化
大規(guī)模下如何保持精度
其中1主要是通信庫(kù)的優(yōu)化,嚴(yán)格來(lái)說(shuō)和神經(jīng)網(wǎng)絡(luò)本身并沒(méi)有關(guān)系,這里面比較優(yōu)秀的工作有經(jīng)典的ring-base all-reduce(最先在百度的工作中被用于神經(jīng)網(wǎng)絡(luò)訓(xùn)練baidu-research/baidu-allreduce:
https://github.com/baidu-research/baidu-allreduce
騰訊的分層通信:
https://arxiv.org/abs/1807.11205
以及sony的2D all-reduce(Massively Distributed SGD: ImageNet/ResNet-50 Training in a Flash:
https://arxiv.org/abs/1811.05233
而第2部分的工作都針對(duì)于如何在神經(jīng)網(wǎng)絡(luò)這個(gè)訓(xùn)練模式下做通信優(yōu)化。這方面的思路很廣,比如商湯提出的稀疏通信:
https://arxiv.org/abs/1902.06855
杜克大學(xué)提出的TernGrad (TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning:
https://arxiv.org/abs/1705.07878
第三部分和前兩個(gè)不同,主要關(guān)注點(diǎn)在于精度而非性能。在大規(guī)模訓(xùn)練的情況下,一種常見(jiàn)的做法是做數(shù)據(jù)并行,即把batch size設(shè)的很大,那么原來(lái)跑90個(gè)epoch需要迭代1000次的話,把batch size擴(kuò)大10倍,就只需要迭代100次,即參數(shù)的更新次數(shù)減少了很多。如何在這種情況下收斂到小batch size也是一個(gè)棘手的問(wèn)題。在這個(gè)領(lǐng)域比較好的工作有face book的線性倍增學(xué)習(xí)率(https://arxiv.org/pdf/1706.02677.pdf)以及伯克利尤洋的LAR算法(https://arxiv.org/pdf/1709.05011.pdf)。
對(duì)通信本身的優(yōu)化
(懶得寫(xiě)了,偷個(gè)懶)我對(duì)這方面了解十分有限,推薦大家讀騰訊團(tuán)隊(duì)寫(xiě)的介紹(蘭瑞Frank:騰訊機(jī)智團(tuán)隊(duì)分享--AllReduce算法的前世今生:
https://zhuanlan.zhihu.com/p/79030485
神經(jīng)網(wǎng)絡(luò)的通信優(yōu)化
分布式神經(jīng)網(wǎng)絡(luò)訓(xùn)練目前主要有兩種模式:數(shù)據(jù)并行和模型并行。
數(shù)據(jù)并行比較簡(jiǎn)單,下面這張圖是經(jīng)典的數(shù)據(jù)并行的同步訓(xùn)練的場(chǎng)景:所有節(jié)點(diǎn)(即圖中的GPU0-GPU3)都保存整個(gè)模型(粉色的Params),每次迭代,不同的節(jié)點(diǎn)會(huì)得到不同的數(shù)據(jù),每個(gè)節(jié)點(diǎn)用得到的數(shù)據(jù)做正向和反向計(jì)算,得到每個(gè)參數(shù)的梯度。之后整個(gè)分布式系統(tǒng)會(huì)同步所有節(jié)點(diǎn)的梯度,即每個(gè)節(jié)點(diǎn)的local gradient做一次all reduce操作,得到全局的global gradient(最下面藍(lán)色的Gradients)。每個(gè)節(jié)點(diǎn)用這個(gè)global gradient更新參數(shù)。
顯而易見(jiàn),數(shù)據(jù)并行基于一個(gè)假設(shè):每個(gè)節(jié)點(diǎn)都可以放下整個(gè)模型。這個(gè)假設(shè)在如今某些模型上(說(shuō)的就是你,GPT3!!!)是不合理的,因此我們還需要模型并行,即不同節(jié)點(diǎn)負(fù)責(zé)計(jì)算神經(jīng)網(wǎng)絡(luò)模型的不同部分(比如有一個(gè)100層的網(wǎng)絡(luò),那么我們可以讓第一個(gè)節(jié)點(diǎn)存儲(chǔ)前50層的參數(shù),并負(fù)責(zé)計(jì)算前50層,另一個(gè)網(wǎng)絡(luò)則負(fù)責(zé)后面50層)。
下面這張圖摘自英偉達(dá)的Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism:
https://arxiv.org/abs/1909.08053
在這里演示了如何用兩個(gè)節(jié)點(diǎn)去算連續(xù)的兩個(gè)矩陣乘法。
我們要做的操作是首先算出Y=GeLU(XA),再算Z=Dropoug(YB)。其中,X,A,B都是矩陣,而且矩陣規(guī)模都很大。
假設(shè)我們希望用兩個(gè)分布式節(jié)點(diǎn)完成這個(gè)計(jì)算,那么我們可以把矩陣A按colum切成A1,A2兩份,分別存到節(jié)點(diǎn)0和節(jié)點(diǎn)1中。同時(shí)我們也把矩陣B按行切成B1,B2兩份,分別存到節(jié)點(diǎn)0和節(jié)點(diǎn)1中。然后我們將X做一個(gè)broadcast(圖中f部分),分別發(fā)送到兩個(gè)節(jié)點(diǎn)上,算得Z1和Z2,在做一次all reduce(圖中g(shù)部)將Z1和Z2相加,得到最終的Z。
這里面有一個(gè)很巧(也很繞)的地方,那就是為什么A要按列切,B要按行切?我們可不可以把它們反過(guò)來(lái)?答案是:最好不要,因?yàn)槿绻催^(guò)來(lái),的確計(jì)算上可行,但是我們就會(huì)增加一次通信(即算Y=XA的時(shí)候我們就要做一次通信),這樣顯然速度會(huì)變慢。
展開(kāi)來(lái)講,數(shù)據(jù)并行和模型并行也可以細(xì)分。數(shù)據(jù)并行可以分為:
同步式數(shù)據(jù)并行
異步式數(shù)據(jù)并行
同步式比較簡(jiǎn)單,就是我最上面那張圖演示的。
異步式復(fù)雜一些:我們很容易發(fā)現(xiàn),最后全局all reduce gradient的時(shí)候會(huì)耗時(shí)比較多,分布式系統(tǒng)越大,消耗越大,而且這樣做還有一個(gè)隱藏的假設(shè):分布式系統(tǒng)是homogeneous的,即每個(gè)分布式節(jié)點(diǎn)不會(huì)差的很多。舉個(gè)例子,如果每個(gè)節(jié)點(diǎn)實(shí)力相當(dāng),那么都會(huì)算10s就可以結(jié)束一個(gè)iteration,那么我們10s之后就可以開(kāi)始一次通信。然而如果有一個(gè)節(jié)點(diǎn)(害群之馬)需要算100s,那么其他節(jié)點(diǎn)算完之后就得干等它90s才能做通信,那么是對(duì)資源的極大浪費(fèi)。
想想看,你的老板絕對(duì)不允許你(打工人)干坐著什么事都不干,只因?yàn)槟愕倪M(jìn)度被別的同事block了。研究員也是如此,于是為了解決上面的問(wèn)題,引入了異步式通信。簡(jiǎn)單來(lái)說(shuō)就是如果遭遇了上面的情況,快的節(jié)點(diǎn)等一會(huì)兒就不等了,他們之間做一次通信然后接著算下一輪。這個(gè)節(jié)點(diǎn)什么時(shí)候算好什么時(shí)候再和其他人一起all reduce梯度。
這樣做快是快了,但引入了另一個(gè)問(wèn)題,那就是每個(gè)人的參數(shù)都不一樣了,那么他們根據(jù)不同的參數(shù)算得的梯度再去做all reduce就有一些不合理,就會(huì)導(dǎo)致神經(jīng)網(wǎng)絡(luò)精度受損。
有很多工作嘗試解決異步并行帶來(lái)的精度損失,不過(guò)據(jù)我所知并沒(méi)有特別general的方法,因此異步并行如今也很少被使用了。模型并行可以分為:
粗粒度并行
細(xì)粒度并行
它們的區(qū)別在于并行的層級(jí):粗粒度每個(gè)節(jié)點(diǎn)會(huì)算不同的layer,而細(xì)粒度會(huì)將layer也做拆。
分粗粒度并行比較優(yōu)秀的工作有g(shù)oogle的GPipe(https://arxiv.org/pdf/1811.06965.pdf)
在粗粒度并行中,每個(gè)節(jié)點(diǎn)負(fù)責(zé)不同的layer,但是layer之間是存在數(shù)據(jù)依賴的,這就導(dǎo)致在之前的節(jié)點(diǎn)算的時(shí)候,后面的節(jié)點(diǎn)干等著。GPipe提出把數(shù)據(jù)按照batch緯度做切分得到多個(gè)micro batch,這樣第一個(gè)節(jié)點(diǎn)先算第一個(gè)micro batch(圖中F[0,0]),把算到的結(jié)果發(fā)給第二個(gè)節(jié)點(diǎn)去算,于是下一個(gè)時(shí)刻第二個(gè)節(jié)點(diǎn)在算第一個(gè)micro batch(F[1,0]),而第一個(gè)節(jié)點(diǎn)開(kāi)始算第二個(gè)micro batch(F[0,1])。
細(xì)粒度并行比較好的工作除了我之前介紹的Megatron之外,還有GShard(GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding(https://arxiv.org/abs/2006.16668)
這個(gè)工作主要的貢獻(xiàn)在于提供了一套原語(yǔ),允許最高層的開(kāi)發(fā)者(寫(xiě)python的人)通過(guò)簡(jiǎn)單的方式指導(dǎo)代碼生成(即編譯器)生成對(duì)應(yīng)的模型并行的代碼。
后臺(tái)回復(fù)關(guān)鍵詞【入群】
加入賣(mài)萌屋NLP/IR/Rec與求職討論群
后臺(tái)回復(fù)關(guān)鍵詞【頂會(huì)】
獲取ACL、CIKM等各大頂會(huì)論文集!
總結(jié)
以上是生活随笔為你收集整理的谈谈神经网络的大规模训练优化的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 别再Prompt了!谷歌提出tuning
- 下一篇: Linux 程 序 员 失 业 警 告