ICLR 2021 | 显存不够?不妨抛弃端到端训练
?作者|王語霖
學(xué)校|清華大學(xué)自動化系博士生
研究方向|機(jī)器學(xué)習(xí)和計算機(jī)視覺
本文主要介紹我們被 ICLR 2021 接收的一篇文章,代碼已經(jīng)在?Github 上面開源。
論文標(biāo)題:
Revisiting Locally Supervised Learning: an Alternative to End-to-end Training
論文鏈接:
https://openreview.net/forum?id=fAbkE6ant2
代碼鏈接:
https://github.com/blackfeather-wang/InfoPro-Pytorch
(太長不看版)本文研究了一種比目前廣為使用的端到端訓(xùn)練模式顯存開銷更小、更容易并行化的訓(xùn)練方法:將網(wǎng)絡(luò)拆分成若干段、使用局部監(jiān)督信號進(jìn)行訓(xùn)練。我們指出了這一范式的一大缺陷在于損失網(wǎng)絡(luò)整體性能,并從信息的角度闡明了,其癥結(jié)在于局部監(jiān)督傾向于使網(wǎng)絡(luò)在淺層損失對深層網(wǎng)絡(luò)有很大價值的任務(wù)相關(guān)信息。
為有效解決這一問題,我們提出了一種局部監(jiān)督學(xué)習(xí)算法:InfoPro。在圖像識別和語義分割任務(wù)上的實驗結(jié)果表明,我們的算法可以在不顯著增大訓(xùn)練時間的前提下,有效節(jié)省顯存開銷,并提升性能。
研究動機(jī)及簡介
一般而言,深度神經(jīng)網(wǎng)絡(luò)以端到端的形式訓(xùn)練。以一個 13 層的簡單卷積神經(jīng)網(wǎng)絡(luò)為例,我們會將訓(xùn)練數(shù)據(jù)輸入網(wǎng)絡(luò)中,逐層前傳至最后一層,輸出結(jié)果,計算損失值(End-to-End Loss),再從損失求得梯度,將之逐層反向傳播以更新網(wǎng)絡(luò)參數(shù)。
▲ 圖1 端到端訓(xùn)練(End-to-End Training)
盡管端到端訓(xùn)練在大量任務(wù)中都穩(wěn)定地表現(xiàn)出了良好的效果,但其效率至少在以下兩方面仍然有待提升。其一,端到端訓(xùn)練需要在網(wǎng)絡(luò)前傳時將每一層的輸出進(jìn)行存儲,并在逐層反傳梯度時使用這些值,這造成了極大的顯存開銷,如下圖所示。
▲ 圖2 端到端訓(xùn)練具有較大的的顯存開銷
其二,對整個網(wǎng)絡(luò)進(jìn)行前傳-->反傳的這一范式是一個固有的線性過程。前傳時深層網(wǎng)絡(luò)必須等待淺層網(wǎng)絡(luò)的計算完成后才能開始自身的前傳過程;同理,反傳時淺層網(wǎng)絡(luò)需要等待來自深層網(wǎng)絡(luò)的梯度信號才能進(jìn)行自身的運(yùn)算。這兩點(diǎn)線性的限制使得端到端訓(xùn)練很難進(jìn)行并行化以進(jìn)一步的提升效率。
▲ 圖3 端到端訓(xùn)練難以并行化
為了解決或緩解上述兩點(diǎn)低效的問題,一個可能的方案是使用局部監(jiān)督學(xué)習(xí),即將網(wǎng)絡(luò)拆分為若干個局部模塊(local module),并在每個模塊的末端添加一個局部損失,利用這些局部損失產(chǎn)生監(jiān)督信號分別訓(xùn)練各個局部模塊,注意不同模塊間沒有梯度上的聯(lián)通。下圖給出了一個將網(wǎng)絡(luò)拆分為兩段的例子。
▲ 圖4 局部監(jiān)督學(xué)習(xí)(Locally Supervised Learning)
相較于端到端訓(xùn)練的兩點(diǎn)不足,局部監(jiān)督學(xué)習(xí)在效率上先天具有顯著優(yōu)勢。其一,我們一次只需保存一個局部模塊內(nèi)的中間層輸出值,待此模塊完成反向傳播后,即可釋放存儲空間,進(jìn)而復(fù)用同樣的空間用以存儲下一個局部模塊的中間層輸出值,如下圖所示。簡言之,理論上顯存開銷隨局部模塊數(shù)呈指數(shù)級下降。
▲ 圖5 局部監(jiān)督學(xué)習(xí)可有效降低顯存開銷
其二,不同局部模塊的反向傳播過程并沒有必然的前后依賴關(guān)系,在工程實現(xiàn)上,不同模塊的訓(xùn)練可以自然的并行完成,例如分別使用不同的 GPU,如下圖所示。
▲ 圖6 局部監(jiān)督學(xué)習(xí)易于并行化完成
問題分析與假設(shè)
相信大家看到這里,都會有一個問題:既然局部監(jiān)督學(xué)習(xí)的效率自然地高于端到端訓(xùn)練,為什么它現(xiàn)在沒有被大規(guī)模應(yīng)用呢?其問題在于,局部監(jiān)督學(xué)習(xí)往往會損害網(wǎng)絡(luò)的整體性能。
以圖片識別為例,考慮一種簡單自然的情況,我們使用標(biāo)準(zhǔn)的線性分類器 +SoftMax+ 交叉熵作為每個局部模塊的損失函數(shù),在 CIFAR-10 數(shù)據(jù)集上使用局部監(jiān)督學(xué)習(xí)訓(xùn)練 ResNet-32,結(jié)果如下所示,其中??代表局部模塊的數(shù)目。可以看出隨著 值的增長,網(wǎng)絡(luò)的測試誤差急劇上升。
▲ 圖7 局部監(jiān)督學(xué)習(xí)傾向于損害網(wǎng)絡(luò)性能
若能解決性能下降的問題,局部監(jiān)督學(xué)習(xí)就有可能作為一種更為高效的訓(xùn)練范式而取代端到端訓(xùn)練。出于這一點(diǎn),我們探究和分析了這一問題的原因。
上述局部監(jiān)督學(xué)習(xí)和端到端訓(xùn)練的一個顯著的不同點(diǎn)在于,前者對網(wǎng)絡(luò)的中間層特征直接加入了與任務(wù)直接相關(guān)的監(jiān)督信號,從這一點(diǎn)出發(fā),一個自然的疑問是,由此引發(fā)的中間層特征在任務(wù)相關(guān)行為上的區(qū)別是怎樣的呢?
因此,我們固定了圖 7 中得到的模型,使用網(wǎng)絡(luò)每層的特征訓(xùn)練了一個線性分類器,其測試誤差如下圖右側(cè)所示。其中,橫軸代碼取用特征的網(wǎng)絡(luò)層數(shù),縱軸代表測試誤差,不同的曲線對應(yīng)于不同的? 取值, 表示端到端的情形。
▲ 圖8 中間層特征的線性可分性
從結(jié)果中可以觀察到一個明顯的現(xiàn)象:局部監(jiān)督學(xué)習(xí)所得到的中間層特征在淺層時就體現(xiàn)出了極好的線性可分性,但當(dāng)特征進(jìn)一步經(jīng)過更深的網(wǎng)絡(luò)層時,其線性可分性卻沒有得到進(jìn)一步的增長;相比而言,盡管在淺層時幾乎線性不可分,端到端訓(xùn)練得到的中間層特征隨著層數(shù)的加深可分性逐漸增強(qiáng),最終取得了更低的測試誤差。
于是便產(chǎn)生了一個非常有趣的問題:局部監(jiān)督學(xué)習(xí)中,深層網(wǎng)絡(luò)使用了分辨性遠(yuǎn)遠(yuǎn)強(qiáng)于端到端訓(xùn)練的特征,為何它得到的最終效果卻遜于端到端訓(xùn)練?難道基于可分性已經(jīng)很強(qiáng)的特征,訓(xùn)練網(wǎng)絡(luò)以進(jìn)一步提升其線性可分性,不應(yīng)該得到更好的最終結(jié)果嗎?這似乎與一些之前的觀察(例如 deeply supervised net)矛盾。
為了解答這個疑問,我們進(jìn)一步從信息的角度探究網(wǎng)絡(luò)特征在可分性之外的區(qū)別。我們分別估計了中間層特征??與輸入數(shù)據(jù)??和任務(wù)標(biāo)簽??之間的互信息??和?,并以此作為? 中包含的全部信息和任務(wù)相關(guān)信息的度量指標(biāo)。
▲ 圖9 估算互信息
其結(jié)果如下圖所示,其中橫軸為取用信息的層數(shù),縱軸表示估計值。從中不難看出,端到端訓(xùn)練的網(wǎng)絡(luò)中,特征所包含的總信息量逐層減少,但任務(wù)相關(guān)信息維持不變,說明網(wǎng)絡(luò)逐層剔除了與任務(wù)無關(guān)的信息。與之形成鮮明對比的是,局部監(jiān)督學(xué)習(xí)得到的網(wǎng)絡(luò)在淺層就丟失了大量的任務(wù)相關(guān)信息,特征所包含的總信息量也急劇下降。
我們猜測,這一現(xiàn)象的原因在于,僅憑淺層網(wǎng)絡(luò)難以如全部網(wǎng)絡(luò)一般有效分離和利用所有任務(wù)相關(guān)信息,因此索性去丟棄部分無法利用的信息換取局部訓(xùn)練損失的降低。而在這種情況下,網(wǎng)絡(luò)深層接收到的特征相較網(wǎng)絡(luò)原始輸入本就缺少關(guān)鍵信息,自然難以基于其建立更有效的表征,也就難以取得更好的最終性能。
▲ 圖10 中間層特征包含的信息
基于上述觀察,我們可以總結(jié)得到:局部監(jiān)督學(xué)習(xí)之所以會損害網(wǎng)絡(luò)的整體性能,是因為其傾向于使網(wǎng)絡(luò)在淺層丟失與任務(wù)相關(guān)的信息,從而使得深層網(wǎng)絡(luò)空有更多的參數(shù)和更大的容量,卻因輸入特征先天不足而無用武之地。
方法詳述
為了解決損失信息的問題,本文提出了一種專為局部監(jiān)督學(xué)習(xí)定制的損失函數(shù):InfoPro。首先,我們引入一個基本模型。如下圖所示,我們假設(shè)訓(xùn)練數(shù)據(jù)受到兩個隨機(jī)變量影響,其一是任務(wù)標(biāo)簽?,決定我們所關(guān)心的主體內(nèi)容;其二是無關(guān)變量?,用于決定數(shù)據(jù)中與任務(wù)無關(guān)的部分,例如背景、視角、天氣等。
▲ 圖11 變量作用關(guān)系假設(shè)
基于上述變量設(shè)置,我們將 InfoPro 損失函數(shù)定義為下面的結(jié)合形式。它用于作為局部監(jiān)督信號訓(xùn)練局部模塊,由兩項組成。第一項用于推動局部模塊向前傳遞所有信息;在第二項中,我們使用一個滿足特殊條件無關(guān)變量?來建模中間層特征中的全部任務(wù)無關(guān)信息(無用信息),在此基礎(chǔ)上迫使局部模塊剔除這些與任務(wù)無關(guān)的信息。
▲ 圖12 InfoPro損失函數(shù)
InfoPro 與端到端訓(xùn)練和在 2. Analysis 中所述的簡單局部監(jiān)督學(xué)習(xí)(Greedy Supervised Learning)的對比如下圖所示。簡言之,InfoPro 的目標(biāo)是使得局部模塊能夠在保證向前傳遞全部有價值信息的條件下,盡可能丟棄特征中的無用信息,以解決局部監(jiān)督學(xué)習(xí)在淺層丟失任務(wù)相關(guān)信息、影響網(wǎng)絡(luò)最終性能的問題。
事實上,這也是我們前面觀察到的、端到端訓(xùn)練對網(wǎng)絡(luò)淺層的影響形式。InfoPro 與其它局部學(xué)習(xí)方法最大的區(qū)別在于它是非貪婪的,并不直接對局部的任務(wù)相關(guān)行為(如Greedy Supervised Learning 中基于局部特征的分類損失)做出直接約束。
▲ 圖13 3種訓(xùn)練算法的對比
在具體實現(xiàn)上,由于 InfoPro 損失的第二項比較難以估算,我們推導(dǎo)出了其的一個易于計算的上界,如下圖所示:
▲ 圖14 InfoPro損失的一個易于計算的上界
關(guān)于這一上界的具體推導(dǎo)過程、一些數(shù)學(xué)性質(zhì)和其實際上的計算方式,由于流程比較復(fù)雜且不關(guān)鍵,不在此贅述,歡迎感興趣的讀者參閱我們的文章~
實驗結(jié)果
●?在不同局部模塊數(shù)目的條件下,穩(wěn)定勝過 baseline
●?大量節(jié)省顯存,且不引入顯著的額外計算/時間開銷,效果相較端到端訓(xùn)練略有提升
●?ImageNet 大規(guī)模圖像識別任務(wù)上的結(jié)果,節(jié)省顯存的效果同樣顯著,效果略有提升
●?Cityscapes 語義分割實驗結(jié)果,除節(jié)省顯存方面的作用外,我們還證明了,在相同的顯存限制下,InfoPro 可以使用更大的 batch size 或更大分辨率的輸入圖片
結(jié)語
總結(jié)來說,這項工作的要點(diǎn)在于:(1)從效率的角度反思端到端訓(xùn)練范式;(2)指出了局部監(jiān)督學(xué)習(xí)相較于端到端的缺陷在于損失網(wǎng)絡(luò)性能,并從信息的角度分析了其原因;(3)在理論上提出了初步解決方案,并探討了具體實現(xiàn)方法。
歡迎大家 follow 我們的工作~
@inproceedings{wang2021revisiting,title?=?{Revisiting?Locally?Supervised?Learning:?an?Alternative?to?End-to-end?Training},author?=?{Yulin?Wang?and?Zanlin?Ni?and?Shiji?Song?and?Le?Yang?and?Gao?Huang},booktitle?=?{International?Conference?on?Learning?Representations?(ICLR)},year?=?{2021},url?=?{https://openreview.net/forum?id=fAbkE6ant2} }如有任何問題,歡迎留言或者給我發(fā)郵件,附上我的主頁鏈接:
http://www.rainforest-wang.cool
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識的人。
總有一些你不認(rèn)識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標(biāo)準(zhǔn):
? 稿件確系個人原創(chuàng)作品,來稿需注明作者個人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請在投稿時提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會添加“原創(chuàng)”標(biāo)志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨(dú)在附件中發(fā)送?
? 請留下即時聯(lián)系方式(微信或手機(jī)),以便我們在編輯發(fā)布時和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學(xué)術(shù)平臺。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結(jié)
以上是生活随笔為你收集整理的ICLR 2021 | 显存不够?不妨抛弃端到端训练的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 怎样消灭家里的蟑螂?
- 下一篇: Python 和 C/C++ 拓展程序的