【AI不惑境】模型压缩中知识蒸馏技术原理及其发展现状和展望
大家好,這是專欄《AI不惑境》的第十一篇文章,講述知識(shí)蒸餾相關(guān)的內(nèi)容。
進(jìn)入到不惑境界,就是向高手邁進(jìn)的開始了,在這個(gè)境界需要自己獨(dú)立思考。如果說(shuō)學(xué)習(xí)是一個(gè)從模仿,到追隨,到創(chuàng)造的過(guò)程,那么到這個(gè)階段,應(yīng)該躍過(guò)了模仿和追隨的階段,進(jìn)入了創(chuàng)造的階段。從這個(gè)境界開始,講述的問(wèn)題可能不再有答案,更多的是激發(fā)大家一起來(lái)思考。
作者&編輯 | 言有三
知識(shí)蒸餾是非常經(jīng)典的基于遷移學(xué)習(xí)的模型壓縮技術(shù),在學(xué)術(shù)界的研究非常活躍,工業(yè)界也有許多的應(yīng)用和較大的潛力,本文給大家梳理知識(shí)蒸餾的核心技術(shù),發(fā)展現(xiàn)狀,未來(lái)展望以及學(xué)習(xí)資源推薦。
1 知識(shí)蒸餾基礎(chǔ)
1.1 什么是知識(shí)蒸餾
一般地,大模型往往是單個(gè)復(fù)雜網(wǎng)絡(luò)或者是若干網(wǎng)絡(luò)的集合,擁有良好的性能和泛化能力,而小模型因?yàn)榫W(wǎng)絡(luò)規(guī)模較小,表達(dá)能力有限。利用大模型學(xué)習(xí)到的知識(shí)去指導(dǎo)小模型訓(xùn)練,使得小模型具有與大模型相當(dāng)?shù)男阅?#xff0c;但是參數(shù)數(shù)量大幅降低,從而可以實(shí)現(xiàn)模型壓縮與加速,就是知識(shí)蒸餾與遷移學(xué)習(xí)在模型優(yōu)化中的應(yīng)用。
Hinton等人最早在文章“Distilling the knowledge in a neural network”[1]中提出了知識(shí)蒸餾這個(gè)概念,其核心思想是一旦復(fù)雜網(wǎng)絡(luò)模型訓(xùn)練完成,便可以用另一種訓(xùn)練方法從復(fù)雜模型中提取出來(lái)更小的模型,因此知識(shí)蒸餾框架通常包含了一個(gè)大模型(被稱為teacher模型),和一個(gè)小模型(被稱為student模型)。
1.2 為什么要進(jìn)行知識(shí)蒸餾
以計(jì)算機(jī)視覺(jué)模型的訓(xùn)練為例,我們經(jīng)常用在ImageNet上訓(xùn)練的模型作為預(yù)訓(xùn)練模型,之所以可以這樣做,是因?yàn)樯疃葘W(xué)習(xí)模型在網(wǎng)絡(luò)淺層學(xué)習(xí)的知識(shí)是圖像的色彩和邊緣等底層信息,某一個(gè)數(shù)據(jù)集學(xué)習(xí)到的信息也可以應(yīng)用于其他領(lǐng)域。
那實(shí)際上知識(shí)蒸餾或者說(shuō)遷移學(xué)習(xí)的必要性在哪里?
(1) 數(shù)據(jù)分布差異。深度學(xué)習(xí)模型的訓(xùn)練場(chǎng)景和測(cè)試場(chǎng)景往往有分布差異,以自動(dòng)駕駛領(lǐng)域?yàn)槔?#xff0c;大部分?jǐn)?shù)據(jù)集的采集都是基于白天,光照良好的天氣條件下,在這樣的數(shù)據(jù)集上訓(xùn)練的模型,當(dāng)將其用于黑夜,風(fēng)雪等場(chǎng)景時(shí),很有可能會(huì)無(wú)法正常工作,從而使得模型的實(shí)用性能非常受限。因此,必須考慮模型從源域到目標(biāo)域的遷移能力。
(2) 受限的數(shù)據(jù)規(guī)模。數(shù)據(jù)的標(biāo)注成本是非常高的,導(dǎo)致很多任務(wù)只能用少量的標(biāo)注進(jìn)行模型的訓(xùn)練。以醫(yī)學(xué)領(lǐng)域?yàn)榈湫痛?#xff0c;數(shù)據(jù)集的規(guī)模并不大,因此在真正專用的模型訓(xùn)練之前往往需要先在通用任務(wù)上進(jìn)行預(yù)訓(xùn)練。
(3) 通用與垂直領(lǐng)域。雖然我們可以訓(xùn)練許多通用的模型,但是真實(shí)需求是非常垂直或者說(shuō)個(gè)性化的,比如ImageNet存在1000類,但是我們可能只需要用到其中若干類。此時(shí)就可以基于1000類ImageNet模型進(jìn)行知識(shí)遷移,而不需要完全從頭開始訓(xùn)練。
因此,在工業(yè)界對(duì)知識(shí)蒸餾和遷移學(xué)習(xí)也有著非常強(qiáng)烈的需求,接下來(lái)我們講解其中的主要算法。
2 知識(shí)蒸餾主要算法
知識(shí)蒸餾是對(duì)模型的能力進(jìn)行遷移,根據(jù)遷移的方法不同可以簡(jiǎn)單分為基于目標(biāo)驅(qū)動(dòng)的算法,基于特征匹配的算法兩個(gè)大的方向,下面我們對(duì)其進(jìn)行介紹。
2.1 知識(shí)蒸餾基本框架
Hinton最早在文章“Distilling the knowledge in a neural network”中提出了知識(shí)蒸餾的概念,即knowledge distilling,對(duì)后續(xù)的許多算法都產(chǎn)生了影響,其框架示意圖如下:
從上圖中可以看出,包括一個(gè)teacher model和一個(gè)student model,teacher model需要預(yù)先訓(xùn)練好,使用的就是標(biāo)準(zhǔn)分類softmax損失,但是它的輸出使用帶溫度參數(shù)T的softmax函數(shù)進(jìn)行映射,如下:
當(dāng)T=1時(shí),就是softmax本身。當(dāng)T>1,稱之為soft softmax,T越大,因?yàn)檩斎雤k產(chǎn)生的概率f(zk)差異就會(huì)越小。
之所以要這么做,其背后的思想是當(dāng)訓(xùn)練好一個(gè)模型之后,模型為所有的誤標(biāo)簽都分配了很小的概率。然而實(shí)際上對(duì)于不同的錯(cuò)誤標(biāo)簽,其被分配的概率仍然可能存在數(shù)個(gè)量級(jí)的懸殊差距。這個(gè)差距,在softmax中直接就被忽略了,但這其實(shí)是一部分有用的信息。
訓(xùn)練的時(shí)候小模型有兩個(gè)損失,一個(gè)是與真實(shí)標(biāo)簽的softmax損失,一個(gè)是與teacher model的蒸餾損失,定義為KL散度。
當(dāng)teacher model和student model各自的預(yù)測(cè)概率為pi,qi時(shí),其蒸餾損失部分梯度傳播如下:
可以看出形式非常的簡(jiǎn)單,梯度為兩者預(yù)測(cè)概率之差,這就是最簡(jiǎn)單的知識(shí)蒸餾框架。
2.2 優(yōu)化目標(biāo)驅(qū)動(dòng)的知識(shí)蒸餾框架
Hinton等人提出的框架是在模型最后的預(yù)測(cè)端,讓student模型學(xué)習(xí)到與teacher模型的知識(shí),這可以稱之為直接使用優(yōu)化目標(biāo)進(jìn)行驅(qū)動(dòng)的框架,類似的還有ProjectionNet[2]。
PrjojectNet同時(shí)訓(xùn)練一個(gè)大模型和一個(gè)小模型,兩者的輸入都是樣本,其中大模型就是普通的CNN網(wǎng)絡(luò),而小模型會(huì)對(duì)輸入首先進(jìn)行特征投影。每一個(gè)投影矩陣P都對(duì)應(yīng)了一個(gè)映射,由一個(gè)d-bit長(zhǎng)的向量表示,其中每一個(gè)bit為0或者1,這是一個(gè)更加稀疏的表達(dá)。特征用這種方法簡(jiǎn)化后自然就可以使用更加輕量的網(wǎng)絡(luò)的結(jié)構(gòu)進(jìn)行訓(xùn)練。
那么怎么完成這個(gè)過(guò)程呢?文中使用的是locality sensitive hashing(LSH)算法,這是一種聚類任務(wù)中常用的降維的算法。
優(yōu)化目標(biāo)包含了3部分,分別是大模型的損失,投影損失,以及大模型和小模型的預(yù)測(cè)損失,全部使用交叉熵,各自定義如下:
基于優(yōu)化目標(biāo)驅(qū)動(dòng)的方法其思想是非常直觀,就是結(jié)果導(dǎo)向型,中間怎么實(shí)現(xiàn)的不關(guān)心,對(duì)它進(jìn)行改進(jìn)的一個(gè)有趣方向是GAN的運(yùn)用。
2.3 特征匹配的知識(shí)蒸餾框架
結(jié)果導(dǎo)向型的知識(shí)蒸餾框架的具體細(xì)節(jié)是難以控制的,會(huì)讓訓(xùn)練變得不穩(wěn)定且緩慢。一種更直觀的方式是將teacher模型和student模型的特征進(jìn)行約束,從而保證student模型確實(shí)繼承了teacher模型的知識(shí),其中一個(gè)典型代表就是FitNets[3],FitNets將比較淺而寬的Teacher模型的知識(shí)遷移到更窄更深的Student模型上,框架如下:
FitNets背后的思想是,用網(wǎng)絡(luò)的中間層的特征進(jìn)行匹配,不僅僅是在輸出端。
它的訓(xùn)練包含了兩個(gè)階段:
第一階段就是根據(jù)Teacher模型的損失來(lái)指導(dǎo)預(yù)訓(xùn)練Student模型。記Teacher網(wǎng)絡(luò)的某一中間層的權(quán)值Wt為Whint,意為指導(dǎo)的意思。Student網(wǎng)絡(luò)的某一中間層的權(quán)值Ws為Wguided,即被指導(dǎo)的意思,在訓(xùn)練之初Student網(wǎng)絡(luò)進(jìn)行隨機(jī)初始化。
我們需要學(xué)習(xí)一個(gè)映射函數(shù)Wr使得Wguided的維度匹配Whint,得到Ws',并最小化兩者網(wǎng)絡(luò)輸出的MSE差異作為損失,如下:
第二個(gè)訓(xùn)練階段,就是對(duì)整個(gè)網(wǎng)絡(luò)進(jìn)行知識(shí)蒸餾訓(xùn)練,與上述Hinton等人提出的策略一致。
不過(guò)FitNet直接將特征值進(jìn)行了匹配,先驗(yàn)約束太強(qiáng),有的框架對(duì)激活值進(jìn)行了歸一化[4]。
基于特征空間進(jìn)行匹配的方法其實(shí)是知識(shí)蒸餾的主流,類似的方法非常多,包括注意力機(jī)制的使用[5],類似于風(fēng)格遷移算法的特征匹配[6]等。
3 知識(shí)蒸餾算法的展望
上一節(jié)我們介紹了知識(shí)蒸餾的基本方法,當(dāng)然知識(shí)蒸餾還有非常多有意思的研究方向,這里我們介紹其中幾個(gè)。
3.1 不壓縮模型
機(jī)器學(xué)習(xí)模型要解決的問(wèn)題如下,其中y是預(yù)測(cè)值,x是輸入,L是優(yōu)化目標(biāo),θ1是優(yōu)化參數(shù)。
因?yàn)樯疃葘W(xué)習(xí)模型沒(méi)有解析解,往往無(wú)法得到最優(yōu)解,我們經(jīng)常會(huì)通過(guò)添加一些正則項(xiàng)來(lái)促使模型達(dá)到更好的性能。
Born Again Neural Networks[7]框架思想是通過(guò)增加同樣的模型架構(gòu),并且重新進(jìn)行優(yōu)化,以增加一個(gè)模型為例,要解決的問(wèn)題如下:
具體的流程就是:
(1) 訓(xùn)練一個(gè)教師模型使其收斂到較好的局部值。
(2) 對(duì)與教師模型結(jié)構(gòu)相同的學(xué)生模型進(jìn)行初始化,其優(yōu)化目標(biāo)包含兩部分,一部分是要匹配教師模型的輸出分布,比如采用KL散度,另一部分就是與教師模型訓(xùn)練時(shí)同樣的目標(biāo),即數(shù)據(jù)集的預(yù)測(cè)真值。
然后通過(guò)下面這樣的流程,一步一步往下傳,所以被形象地命名為“born again”。
類似的框架還有Net2Net,network morphism等。
3.2 去掉teacher模型
一般知識(shí)蒸餾框架都需要包括一個(gè)Teacher模型和一個(gè)Student模型,而Deep mutual learning[8]則沒(méi)有Teacher模型,它通過(guò)多個(gè)小模型進(jìn)行協(xié)同訓(xùn)練,框架示意圖如下。
Deep mutual learning在訓(xùn)練的過(guò)程中讓兩個(gè)學(xué)生網(wǎng)絡(luò)相互學(xué)習(xí),每一個(gè)網(wǎng)絡(luò)都有兩個(gè)損失。一個(gè)是任務(wù)本身的損失,另外一個(gè)就是KL散度。由于KL散度是非對(duì)稱的,所以兩個(gè)網(wǎng)絡(luò)的散度會(huì)不同。
相比單獨(dú)訓(xùn)練,每一個(gè)模型可以取得更高的精度。值得注意的是,就算是兩個(gè)結(jié)構(gòu)完全一樣的模型,也會(huì)學(xué)習(xí)到不同的特征表達(dá)。
3.3 與其他框架的結(jié)合
在進(jìn)行知識(shí)蒸餾時(shí),我們通常假設(shè)teacher模型有更好的性能,而student模型是一個(gè)壓縮版的模型,這不就是模型壓縮嗎?與模型剪枝,量化前后的模型對(duì)比是一樣的。所以知識(shí)蒸餾也被用于與相關(guān)技術(shù)進(jìn)行結(jié)合,apprentice[9]框架是一個(gè)代表。
網(wǎng)絡(luò)結(jié)構(gòu)如上圖所示,Teacher模型是一個(gè)全精度模型,Apprentice模型是一個(gè)低精度模型。
當(dāng)然模型蒸餾還有一些其他方向,以及對(duì)其中每一個(gè)方向的深入解讀。對(duì)模型蒸餾感興趣的同學(xué),歡迎到有三AI知識(shí)星球的網(wǎng)絡(luò)結(jié)構(gòu)1000變-模型壓縮-模型蒸餾板塊進(jìn)行學(xué)習(xí),數(shù)十期內(nèi)容定能滿足你的求知欲。
掃碼即可加入,了解有三AI知識(shí)星球詳情請(qǐng)閱讀以下文章。
【雜談】有三AI知識(shí)星球一周年了!為什么公眾號(hào)+星球才是完整的?
參考文獻(xiàn)
[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
[2] Ravi S. ProjectionNet: Learning Efficient On-Device Deep Networks Using Neural Projections[J]. arXiv: Learning, 2017.
[3] Romero A, Ballas N, Kahou S E, et al. Fitnets: Hints for thin deep nets[J]. arXiv preprint arXiv:1412.6550, 2014.
[4] Huang Z, Wang N. Like What You Like: Knowledge Distill via Neuron Selectivity Transfer.[J]. arXiv: Computer Vision and Pattern Recognition, 2017.
[5] Zagoruyko S, Komodakis N. Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer[C]. international conference on learning representations, 2017.
[6] Yim J, Joo D, Bae J, et al. A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning[C]. computer vision and pattern recognition, 2017: 7130-7138.
[7] Furlanello T, Lipton Z C, Tschannen M, et al. Born Again Neural Networks[C]. international conference on machine learning, 2018: 1602-1611.
[8] Zhang Y, Xiang T, Hospedales T M, et al. Deep Mutual Learning[C]. computer vision and pattern recognition, 2018: 4320-4328.
[9] Mishra A K, Marr D. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy[C]. international conference on learning representations, 2018.
總結(jié)
本次我們總結(jié)了模型蒸餾的核心技術(shù),并對(duì)其重要方向進(jìn)行了展望,推薦了相關(guān)的學(xué)習(xí)資源,下一期我們將介紹AutoML在模型優(yōu)化上的進(jìn)展。
有三AI秋季劃-模型優(yōu)化組
如果你想系統(tǒng)性地學(xué)習(xí)模型優(yōu)化相關(guān)的理論和實(shí)踐,并獲得持續(xù)的指導(dǎo),歡迎加入有三AI秋季劃-模型優(yōu)化組,系統(tǒng)性地學(xué)習(xí)數(shù)據(jù)使用,模型使用和調(diào)參,模型性能分析,緊湊模型設(shè)計(jì),模型剪枝,模型量化,模型部署,NAS等內(nèi)容。
模型優(yōu)化組介紹和往期的一些學(xué)習(xí)內(nèi)容總結(jié)請(qǐng)參考閱讀以下文章:
【通知】如何讓你的2020年秋招CV項(xiàng)目經(jīng)歷更加硬核,可深入學(xué)習(xí)有三秋季劃4大領(lǐng)域32個(gè)方向
【總結(jié)】有三AI秋季劃模型優(yōu)化組3月直播講了哪些內(nèi)容,為什么每一個(gè)從事深度學(xué)習(xí)的同學(xué)都應(yīng)該掌握模型優(yōu)化的內(nèi)容
轉(zhuǎn)載文章請(qǐng)后臺(tái)聯(lián)系
侵權(quán)必究
往期精選
【完結(jié)】深度學(xué)習(xí)CV算法工程師從入門到初級(jí)面試有多遠(yuǎn),大概是25篇文章的距離
【完結(jié)】?jī)?yōu)秀的深度學(xué)習(xí)從業(yè)者都有哪些優(yōu)秀的習(xí)慣
【完結(jié)】給新手的12大深度學(xué)習(xí)開源框架快速入門項(xiàng)目
【完結(jié)】總結(jié)12大CNN主流模型架構(gòu)設(shè)計(jì)思想
【知乎直播】千奇百怪的CNN網(wǎng)絡(luò)架構(gòu)等你來(lái)
【AI不惑境】數(shù)據(jù)壓榨有多狠,人工智能就有多成功
【AI不惑境】網(wǎng)絡(luò)深度對(duì)深度學(xué)習(xí)模型性能有什么影響?
【AI不惑境】網(wǎng)絡(luò)的寬度如何影響深度學(xué)習(xí)模型的性能?
【AI不惑境】學(xué)習(xí)率和batchsize如何影響模型的性能?
【AI不惑境】殘差網(wǎng)絡(luò)的前世今生與原理
【AI不惑境】移動(dòng)端高效網(wǎng)絡(luò),卷積拆分和分組的精髓
【AI不惑境】深度學(xué)習(xí)中的多尺度模型設(shè)計(jì)
【AI不惑境】計(jì)算機(jī)視覺(jué)中注意力機(jī)制原理及其模型發(fā)展和應(yīng)用
【AI不惑境】模型剪枝技術(shù)原理及其發(fā)展現(xiàn)狀和展望
【AI不惑境】模型量化技術(shù)原理及其發(fā)展現(xiàn)狀和展望
總結(jié)
以上是生活随笔為你收集整理的【AI不惑境】模型压缩中知识蒸馏技术原理及其发展现状和展望的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【通知】如何让你的2020年秋招CV项目
- 下一篇: 【AI不惑境】AutoML在深度学习模型