[深度学习]知识蒸馏技术
一 知識蒸餾(Knowledge Distillation)介紹
名詞解釋
- teacher - 原始模型或模型ensemble
- student - 新模型
- transfer set - 用來遷移teacher知識、訓練student的數據集合
- soft target - teacher輸出的預測結果(一般是softmax之后的概率)
- hard target - 樣本原本的標簽
- temperature - 蒸餾目標函數中的超參數
- born-again network - 蒸餾的一種,指student和teacher的結構和尺寸完全一樣
- teacher annealing - 防止student的表現被teacher限制,在蒸餾時逐漸減少soft targets的權重
1.1 什么是知識蒸餾?
在化學中,蒸餾是一種有效的分離不同沸點組分的方法,大致步驟是先升溫使低沸點的組分汽化,然后降溫冷凝,達到分離出目標物質的目的。化學蒸餾條件:(1)蒸餾的液體是混合物;(2)各組分沸點不同。
蒸餾的液體是混合物,這個混合物一定是包含了各種組分,即在我們今天講的知識蒸餾中指原模型包含大量的知識。各組分沸點不同,蒸餾時要根據目標物質的沸點設置蒸餾溫度,即在我們今天講的知識蒸餾中也有“溫度”的概念,那這個“溫度“代表了什么,又是如何選取合適的”溫度“?這里先埋下伏筆,在文中給大家揭曉答案。
?進入我們今天正式的主題,到底什么是知識蒸餾?一般地,大模型往往是單個復雜網絡或者是若干網絡的集合,擁有良好的性能和泛化能力,而小模型因為網絡規模較小,表達能力有限。因此,可以利用大模型學習到的知識去指導小模型訓練,使得小模型具有與大模型相當的性能,但是參數數量大幅降低,從而實現模型壓縮與加速,這就是知識蒸餾與遷移學習在模型優化中的應用。
Hinton等人最早在文章《Distilling the Knowledge in a Neural Network》中提出了知識蒸餾這個概念,其核心思想是先訓練一個復雜網絡模型,然后使用這個復雜網絡的輸出和數據的真實標簽去訓練一個更小的網絡,因此知識蒸餾框架通常包含了一個復雜模型(被稱為Teacher模型)和一個小模型(被稱為Student模型)。
1.2 為什么要有知識蒸餾?
好模型的目標不是擬合訓練數據,而是學習如何泛化到新的數據。所以蒸餾的目標是讓student學習到teacher的泛化能力,理論上得到的結果會比單純擬合訓練數據的student要好。另外,對于分類任務,如果soft targets的熵比hard targets高,那顯然student會學習到更多的信息。
深度學習在計算機視覺、語音識別、自然語言處理等內的眾多領域中均取得了令人難以置信的性能。但是,大多數模型在計算上過于昂貴,無法在移動端或嵌入式設備上運行。因此需要對模型進行壓縮,且知識蒸餾是模型壓縮中重要的技術之一。
1. 提升模型精度
如果對目前的網絡模型A的精度不是很滿意,那么可以先訓練一個更高精度的teacher模型B(通常參數量更多,時延更大),然后用這個訓練好的teacher模型B對student模型A進行知識蒸餾,得到一個更高精度的A模型。
2. 降低模型時延,壓縮網絡參數
如果對目前的網絡模型A的時延不滿意,可以先找到一個時延更低,參數量更小的模型B,通常來講,這種模型精度也會比較低,然后通過訓練一個更高精度的teacher模型C來對這個參數量小的模型B進行知識蒸餾,使得該模型B的精度接近最原始的模型A,從而達到降低時延的目的。
3. 標簽之間的域遷移
假如使用狗和貓的數據集訓練了一個teacher模型A,使用香蕉和蘋果訓練了一個teacher模型B,那么就可以用這兩個模型同時蒸餾出一個可以識別狗、貓、香蕉以及蘋果的模型,將兩個不同域的數據集進行集成和遷移。
因此,在工業界中對知識蒸餾和遷移學習也有著非常強烈的需求。
模型壓縮大體上可以分為 5 種:
-
模型剪枝:即移除對結果作用較小的組件,如減少 head 的數量和去除作用較少的層,共享參數等,ALBERT屬于這種;
-
量化:比如將 float32 降到 float8;
-
知識蒸餾:將 teacher 的能力蒸餾到 student上,一般 student 會比 teacher 小。我們可以把一個大而深的網絡蒸餾到一個小的網絡,也可以把集成的網絡蒸餾到一個小的網絡上。
-
參數共享:通過共享參數,達到減少網絡參數的目的,如 ALBERT 共享了 Transformer 層;
-
參數矩陣近似:通過矩陣的低秩分解或其他方法達到降低矩陣參數的目的;
1.3 這與從頭開始訓練模型有何不同?
顯然,對于更復雜的模型,理論搜索空間要大于較小網絡的搜索空間。但是,如果我們假設使用較小的網絡可以實現相同(甚至相似)的收斂,則教師網絡的收斂空間應與學生網絡的解空間重疊。
不幸的是,僅此一項并不能保證學生網絡在同一位置收斂。學生網絡的收斂可能與教師網絡的收斂大不相同。但是,如果指導學生網絡復制教師網絡的行為(教師網絡已經在更大的解空間中進行了搜索),則可以預期其收斂空間與原始教師網絡收斂空間重疊。
2. 知識蒸餾方式
2.1 知識蒸餾基本框架
知識蒸餾采取Teacher-Student模式:將復雜且大的模型作為Teacher,Student模型結構較為簡單,用Teacher來輔助Student模型的訓練,Teacher學習能力強,可以將它學到的知識遷移給學習能力相對弱的Student模型,以此來增強Student模型的泛化能力。復雜笨重但是效果好的Teacher模型不上線,就單純是個導師角色,真正部署上線進行預測任務的是靈活輕巧的Student小模型。
知識蒸餾是對模型的能力進行遷移,根據遷移的方法不同可以簡單分為基于目標蒸餾(也稱為Soft-target蒸餾或Logits方法蒸餾)和基于特征蒸餾的算法兩個大的方向,下面我們對其進行介紹。
2.2 目標蒸餾-Logits方法
目標蒸餾方法中最經典的論文就是來自于2015年Hinton發表的一篇神作《Distilling the Knowledge in a Neural Network》。下面我們以這篇神作為例,給大家講講目標蒸餾方法的原理。
在這篇論文中,Hinton將問題限定在分類問題下,分類問題的共同點是模型最后會有一個softmax層,其輸出值對應了相應類別的概率值。在知識蒸餾時,由于我們已經有了一個泛化能力較強的Teacher模型,我們在利用Teacher模型來蒸餾訓練Student模型時,可以直接讓Student模型去學習Teacher模型的泛化能力。一個很直白且高效的遷移泛化能力的方法就是:使用softmax層輸出的類別的概率來作為“Soft-target” 。
2.2.1 Hard-target 和 Soft-target
傳統的神經網絡訓練方法是定義一個損失函數,目標是使預測值盡可能接近于真實值(Hard- target),損失函數就是使神經網絡的損失值和盡可能小。這種訓練過程是對ground truth求極大似然。在知識蒸餾中,是使用大模型的類別概率作為Soft-target的訓練過程。
圖:來源于參考文獻2
-
Hard-target:原始數據集標注的 one-shot 標簽,除了正標簽為 1,其他負標簽都是 0。
-
Soft-target:Teacher模型softmax層輸出的類別概率,每個類別都分配了概率,正標簽的概率最高。
知識蒸餾用Teacher模型預測的 Soft-target 來輔助 Hard-target 訓練 Student模型的方式為什么有效呢?softmax層的輸出,除了正例之外,負標簽也帶有Teacher模型歸納推理的大量信息,比如某些負標簽對應的概率遠遠大于其他負標簽,則代表 Teacher模型在推理時認為該樣本與該負標簽有一定的相似性。而在傳統的訓練過程(Hard-target)中,所有負標簽都被統一對待。也就是說,知識蒸餾的訓練方式使得每個樣本給Student模型帶來的信息量大于傳統的訓練方式。
如在MNIST數據集中做手寫體數字識別任務,假設某個輸入的“2”更加形似"3",softmax的輸出值中"3"對應的概率會比其他負標簽類別高;而另一個"2"更加形似"7",則這個樣本分配給"7"對應的概率會比其他負標簽類別高。這兩個"2"對應的Hard-target的值是相同的,但是它們的Soft-target卻是不同的,由此我們可見Soft-target蘊含著比Hard-target更多的信息。
在使用 Soft-target 訓練時,Student模型可以很快學習到 Teacher模型的推理過程;而傳統的 Hard-target 的訓練方式,所有的負標簽都會被平等對待。因此,Soft-target 給 Student模型帶來的信息量要大于 Hard-target,并且Soft-target分布的熵相對高時,其Soft-target蘊含的知識就更豐富。同時,使用 Soft-target 訓練時,梯度的方差會更小,訓練時可以使用更大的學習率,所需要的樣本也更少。這也解釋了為什么通過蒸餾的方法訓練出的Student模型相比使用完全相同的模型結構和訓練數據只使用Hard-target的訓練方法得到的模型,擁有更好的泛化能力。
2.2.2 知識蒸餾的具體方法
在介紹知識蒸餾方法之前,首先得明白什么是Logits。我們知道,對于一般的分類問題,比如圖片分類,輸入一張圖片后,經過DNN網絡各種非線性變換,在網絡最后Softmax層之前,會得到這張圖片屬于各個類別的大小數值,某個類別的??數值越大,則模型認為輸入圖片屬于這個類別的可能性就越大。什么是Logits? 這些匯總了網絡內部各種信息后,得出的屬于各個類別的匯總分值?就是Logits,i代表第i個類別,??代表屬于第i類的可能性。因為Logits并非概率值,所以一般在Logits數值上會用Softmax函數進行變換,得出的概率值作為最終分類結果概率。Softmax一方面把Logits數值在各類別之間進行概率歸一,使得各個類別歸屬數值滿足概率分布;另外一方面,它會放大Logits數值之間的差異,使得Logits得分兩極分化,Logits得分高的得到的概率值更偏大一些,而較低的Logits數值,得到的概率值則更小。
?
神經網絡使用 softmax 層來實現 logits 向 probabilities 的轉換。原始的softmax函數:?
但是直接使用softmax層的輸出值作為soft target,這又會帶來一個問題: 當softmax輸出的概率分布熵相對較小時,負標簽的值都很接近0,對損失函數的貢獻非常小,小到可以忽略不計。因此"溫度"這個變量就派上了用場。下面的公式是加了溫度這個變量之后的softmax函數:
?T?就是溫度。當溫度?T=1?時,這就是標準的 Softmax 公式.? T越高,softmax的output probability distribution越趨于平滑,其分布的熵越大,負標簽攜帶的信息會被相對地放大,模型訓練將更加關注負標簽。
?
知識蒸餾訓練的具體方法如下圖所示,主要包括以下幾個步驟:
?
?
?訓練Teacher的過程很簡單,我們把第2步和第3步過程統一稱為:高溫蒸餾的過程。高溫蒸餾過程的目標函數由distill loss(對應Soft-target)和Student loss(對應Hard-target)加權得到。如下所示:
參數 ?表示訓練中蒸餾損失的比重,考慮到網絡訓練中一般先學習容易樣本,然后學習困難樣本,實際訓練中?會逐漸減小,因為teacher網絡有比較高的置信度給出容易樣本類別概率分布,而對困難樣本置信度不那么高,所有網絡訓練后期使用更多地標注信息做監督會更好。?
?
?第二部分Hard Loss?的必要性其實很好理解:Teacher模型也有一定的錯誤率,使用ground truth可以有效降低錯誤被傳播給Student模型的可能性。打個比喻,老師雖然學識遠遠超過學生,但是他仍然有出錯的可能,而這時候如果學生在老師的教授之外,可以同時參考到標準答案,就可以有效地降低被老師偶爾的錯誤“帶偏”的可能性。
?
2.3 特征蒸餾
另外一種知識蒸餾思路是特征蒸餾方法,如下圖所示。它不像Logits方法那樣,Student只學習Teacher的Logits這種結果知識,而是學習Teacher網絡結構中的中間層特征。最早采用這種模式的工作來自于論文《FITNETS:Hints for Thin Deep Nets》,它強迫Student某些中間層的網絡響應,要去逼近Teacher對應的中間層的網絡響應。這種情況下,Teacher中間特征層的響應,就是傳遞給Student的知識。在此之后,出了各種新方法,但是大致思路還是這個思路,本質是Teacher將特征級知識遷移給Student。因此,接下來我們以這篇論文為主,詳細介紹特征蒸餾方法的原理。
2.3.1 主要解決的問題
這篇論文首先提出一個案例,既寬又深的模型通常需要大量的乘法運算,從而導致對內存和計算的高需求。因此,即使網絡在準確性方面是性能最高的模型,其在現實世界中的應用也受到限制。
為了解決這類問題,我們需要通過模型壓縮(也稱為知識蒸餾)將知識從復雜的模型轉移到參數較少的簡單模型。
到目前為止,知識蒸餾技術已經考慮了Student網絡與Teacher網絡有相同或更小的參數。這里有一個洞察點是,深度是特征學習的基本層面,到目前為止尚未考慮到Student網絡的深度。一個具有比Teacher網絡更多的層但每層具有較少神經元數量的Student網絡稱為“thin deep network”。
因此,該篇論文主要針對Hinton提出的知識蒸餾法進行擴展,允許Student網絡可以比Teacher網絡更深更窄,使用teacher網絡的輸出和中間層的特征作為提示,改進訓練過程和student網絡的性能。
2.3.2 模型結構
-
Student網絡不僅僅擬合Teacher網絡的Soft-target,而且擬合隱藏層的輸出(Teacher網絡抽取的特征);
-
第一階段讓Student網絡去學習Teacher網絡的隱藏層輸出(特征蒸餾);
-
第二階段使用Soft-target來訓練Student網絡(目標蒸餾)。
把“寬”且“深”的網絡蒸餾成“瘦”且“更深”的網絡,需要進行兩階段的訓練:
第一階段:首先選擇待蒸餾的中間層(即Teacher的Hint layer和Student的Guided layer),如圖中綠框和紅框所示。由于兩者的輸出尺寸可能不同,因此,在Guided layer后另外接一層卷積層,使得輸出尺寸與Teacher的Hint layer匹配。接著通過知識蒸餾的方式訓練Student網絡的Guided layer,使得Student網絡的中間層學習到Teacher的Hint layer的輸出.
就是根據Teacher模型的損失來指導預訓練Student模型。記Teacher網絡的前??層作為??,意為指導的意思。Student網絡的前??層作為?,即被指導的意思,在訓練之初Student網絡進行隨機初始化。需要學習一個映射函數??使得??的維度匹配??,得到Student模型在下一階段的參數初始化值,并最小化兩者網絡輸出的MSE差異作為損失(特征蒸餾),如下:
其中,??是教師網絡的部分層的參數(綠框);?是學生網絡的部分層的參數(紅框);?是一個全連接層,用于將兩個網絡輸出的size配齊,因為學生網絡隱藏層寬度比教師網絡窄。
第二階段:?在訓練好Guided layer之后,將當前的參數作為網絡的初始參數,利用知識蒸餾的方式訓練Student網絡的所有層參數,使Student學習Teacher的輸出。由于Teacher對于簡單任務的預測非常準確,在分類任務中近乎one-hot輸出,因此為了弱化預測輸出,使所含信息更加豐富,作者使用Hinton等人論文《Distilling knowledge in a neural network》中提出的softmax改造方法,即在softmax前引入??縮放因子,將Teacher和Student的pre-softmax輸出均除以??。也就是上面我們講的加了溫度的softmax。此時的損失函數為:
其中,??指交叉熵損失函數;?是一個可調整參數,以平衡兩個交叉熵;第一部分為Student的輸出與Ground-truth的交叉熵損失;第二部分為Student與Teacher的softmax輸出的交叉熵損失。
3. 知識蒸餾在NLP/CV中的應用
下面給出這兩種蒸餾方式在自然語言處理和計算機視覺方面的一些頂會論文,方便大家擴展閱讀。
3.1 目標蒸餾-Logits方法應用
-
《Distilling the Knowledge in a Neural Network 》,NIPS,2014。
-
《Deep Mutual Learning》,CVPR,2018。
-
《Born Again Neural Networks》,CVPR,2018。
-
《Distilling Task-Specific Knowledge from BERT into Simple Neural Networks》,2019。
3.2 特征蒸餾方法應用
-
《FitNets: Hints for Thin Deep Nets》,ICLR,2015。
-
《Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer》, ICLR,2017。
-
《A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning》,CVPR,2017。
-
《Learning Efficient Object Detection Models》,NIPS,2017。
深度學習中的知識蒸餾技術(上) - 知乎 (zhihu.com)
深度學習中的知識蒸餾技術(下)-知識蒸餾與推薦系統 - 知乎 (zhihu.com)
總結
以上是生活随笔為你收集整理的[深度学习]知识蒸馏技术的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: QEMU固件模拟技术-stm32仿真分析
- 下一篇: AMD A620 低价主板要来了,华擎已