知识蒸馏概述
知識蒸餾(knowledge distillation)是模型壓縮的一種常用的方法,不同于模型壓縮中的剪枝和量化,知識蒸餾是通過構建一個輕量化的小模型,利用性能更好的大模型的監督信息,來訓練這個小模型,以期達到更好的性能和精度。最早是由Hinton在2015年首次提出(Distilling the Knowledge in a Neural Network)并應用在分類任務上面,這個大模型稱之為Teacher(教師模型),小模型稱之為Student(學生模型)。來自Teacher模型輸出的監督信息稱之為knowledge(知識),而student學習遷移來自teacher的監督信息的過程稱之為Distillation(蒸餾)。
目前知識蒸餾的算法已經廣泛應用到圖像語義識別,目標檢測等場景中,并且針對不同的研究場景,蒸餾方法都做了部分的定制化修改,同時,在行人檢測,人臉識別,姿態檢測,圖像域遷移,視頻檢測等方面,知識蒸餾也是作為一種提升模型性能和精度的重要方法,隨著深度學習的發展,這種技術也會更加的成熟和穩定。
注:Hinton開篇指出,所提方法是為了'壓縮模型',KD能夠讓Student model獲取Teacher model的泛化能力,也即讓小模型能夠干大事情。但KD仍然要訓練Teacher model,并且Student model需要依靠Teacher model得到的soft targets,意味著Teacher model是不可或缺,那這樣的話,'壓縮模型'目的是否真的達到了呢?
壓縮模型的對象一般是指的部署上線的模型,而Teacher model只在訓練的過程中用到。并且同一個Teacher model可以用于蒸餾多個student model。
1.?知識蒸餾作用
①提升模型精度
用戶如果對目前的網絡模型A的精度不是很滿意,那么可以先訓練一個更高精度的teacher模型B(通常參數量更多,時延更大),然后用這個訓練好的teacher模型B對student模型A進行知識蒸餾,得到一個更高精度的模型。
②降低模型時延,壓縮網絡參數
用戶如果對目前的網絡模型A的時延不滿意,可以先找到一個時延更低,參數量更小的模型B,通常來講,這種模型精度也會比較低,然后通過訓練一個更高精度的teacher模型C來對這個參數量小的模型B進行知識蒸餾,使得該模型B的精度接近最原始的模型A,從而達到降低時延的目的。
③圖片標簽之間的域遷移
用戶使用狗和貓的數據集訓練了一個teacher模型A,使用香蕉和蘋果訓練了一個teacher模型B,那么就可以用這兩個模型同時蒸餾出一個可以識別狗,貓,香蕉以及蘋果的模型,將兩個不同與的數據集進行集成和遷移。
④降低標注量
可以通過半監督的蒸餾方式來實現,用戶利用訓練好的teacher網絡模型來對未標注的數據集進行蒸餾,達到降低標注量的目的。
2.?知識蒸餾原理
KD的訓練過程和傳統的訓練過程的對比:
- 傳統training過程(hard targets): 對ground truth求極大似然
 - KD的training過程(soft targets): 用large model的class probabilities作為soft targets
 
?softmax層的輸出,除了正例之外,負標簽也帶有大量的信息,比如某些負標簽對應的概率遠遠大于其他負標簽。而在傳統的訓練過程(hard target)中,所有負標簽都被統一對待。也就是說,KD的訓練方式使得每個樣本給Net-S帶來的信息量大于傳統的訓練方式。
例如,在手寫體數字識別任務MNIST中,假設某個輸入的“2”更加形似"3",softmax的輸出值中"3"對應的概率為0.1,而其他負標簽對應的值都很小,而另一個"2"更加形似"7","7"對應的概率為0.1。這兩個"2"對應的hard target的值是相同的,但是它們的soft target卻是不同的,由此我們可見soft target蘊含著比hard target多的信息。并且soft target分布的熵相對高時,其soft target蘊含的知識就更豐富。
?所以,通過蒸餾的方法訓練出的Net-S相比使用完全相同的模型結構和訓練數據只使用hard target的訓練方法得到的模型擁有更好的泛化能力。
softmax函數加了溫度這個變量:
?原始的softmax函數是T = 1時的特例, T <?1時,概率分布比原始更“陡峭”, T > 1時,概率分布比原始更“平緩”。溫度越高,softmax上各個值的分布就越平均。
溫度的高低改變的是Net-S訓練過程中對負標簽的關注程度: 溫度較低時,對負標簽的關注,尤其是那些顯著低于平均值的負標簽的關注較少;而溫度較高時,負標簽相關的值會相對增大,Net-S會相對多地關注到負標簽。
知識蒸餾第一步是訓練Net-T;第二步是在高溫T下,蒸餾Net-T的知識到Net-S。
主要是第二步:高溫蒸餾的過程
?目標函數由distill loss(對應soft target)和student loss(對應hard target)加權得到。
①Net-S在相同溫度T條件下的softmax輸出和soft target的cross entropy就是Loss函數的第一部:
??②Net-S在T=1的條件下的softmax輸出和ground truth的cross entropy就是Loss函數的第二部分:
總結
                            
                        - 上一篇: 荣耀waterplay鸿蒙,对比发现荣耀
 - 下一篇: 希腊罗马神话传说和《圣经》中的英语成语典