论文笔记:Distilling the Knowledge
生活随笔
收集整理的這篇文章主要介紹了
论文笔记:Distilling the Knowledge
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
原文:Distilling the Knowledge in a Neural Network
Distilling the Knowledge
1、四個問題
要解決什么問題?
-
神經網絡壓縮。
-
我們都知道,要提高模型的性能,我們可以使用ensemble的方法,即訓練多個不同的模型,最后將他們的結果進行融合。像這樣使用ensemble,是最簡單的能提高模型性能的方法,像kaggle之類的比賽中大家也都是這么做的。但是使用多個深度網絡模型進行ensemble常常會有十分巨大的計算開銷,而且算起來整個算法中的模型也會非常大。
-
這篇文章就是為了解決這個問題,提出一套方法,將一個ensemble中的知識信息壓縮到一個小模型中,讓最終的模型更易于部署,節約算力。
用了什么方法解決?
- 提出一種 **知識蒸餾(Knowledge Distillation)**方法,從大模型所學習到的知識中學習有用信息來訓練小模型,在保證性能差不多的情況下進行模型壓縮。
- 知識蒸餾:
- 將一個訓練好的大模型的知識通過遷移學習手段遷移給小模型。
- 換成大白話說就是:先訓練一個復雜的網絡,然后再使用這個網絡來輔助訓練一個小網絡;相當于大網絡是老師,小網絡是學生。
- 知識蒸餾:
- 提出一種新的**集成模型(Ensembles of Models)**方法,包括一個通用模型(Generalist Model)和多個專用模型(Specialist Models),其中,專用模型用來對那些通用模型無法區分的細粒度(Fine-grained)類別的圖像進行區分。
效果如何?
- 文中給了一些語音識別等領域的實驗,通過知識蒸餾訓練得到的小模型的準確率雖不及大模型,但能比較接近了,效果也比較不錯。
還存在什么問題?
- 雖然思想很簡單,實際使用時還需要較多的trick,很難保證每次訓練都能取得很好的效果。
2、論文概述
2.1、知識蒸餾(Knowledge Distillation)
- 對于多分類任務的大模型來說,我們通常都是希望最大化輸出概率的log值,來獲取準確的預測結果。然而,這么做會有一個副作用,就是會賦予所有的非正確答案一定的概率,即使這些概率值都很小。
- 這些非正確答案的相對概率,隱含了大模型在訓練中傾向于如何泛化的信息。
- 為此,Hinton等人在論文中提出了soft target的概念。
- **soft target **指的是大模型最后輸出的概率預測值,即softmax層的輸出。
- 相對的,就有hard target,指的是數據集的標簽了。
- 給softmax加入蒸餾的概念:
- 公式中的TTT表示“溫度”。使用它的目的就是,讓輸出的softmax更平滑,分布更均勻,同時保證各項之間的大小相對關系不變。
- T=1T = 1T=1時,上式就是平常使用的softmax函數。
- 上圖摘自知乎,根據圖很方便理解。
- 當我們不知道數據集的label時,可以直接使用大模型的soft target來訓練小模型。訓練的損失函數是cross-entropy。
- 當我們已知數據集的label時,可以將soft target和小模型預測值的cross-entropy與hard target和小模型預測值得cross-entropy進行加權求合,權重參數為λ\lambdaλ。這種方法得到的效果會更好。
- 實現流程:
2.2、distillation在特殊情況下相當于匹配logits(Matching logits is a special case of distillation )
- 再重復一遍,知識蒸餾(Knowledge Distillation)在特殊情況下相當于logits。
- 下式是cross-entropy到logit的梯度計算公式:
- viv_ivi?是大模型(cumbersome model)產生的logits,ziz_izi?是小模型(distillation model)的logits。
- 如果溫度很高,即TTT很大,那么:
- 假設logits的均值均為0,即:∑jzj/T=∑jvj/T=0\sum_j{z_j / T} = \sum_j{v_j / T} = 0∑j?zj?/T=∑j?vj?/T=0。則有:
- 所以,在TTT很大且logit均值為0時,知識蒸餾就相當于標簽匹配:12(zi?vi)2\frac{1}{2} (z_i - v_i)^221?(zi??vi?)2。
- 當T較小時,蒸餾更加關注負標簽,在訓練復雜網絡的時候,這些負標簽是幾乎沒有約束的,這使得產生的負標簽概率是噪聲比較大的,所以采用更大的T值(上面的簡化方法)是有優勢的。而另一方面,這些負標簽概率也是包含一定的有用信息的,能夠用于幫助簡單網絡的學習。這兩種效應哪種占據主導,是一個實踐問題。(注:我也不是很理解論文中的這段,簡單點說,T就是個超參數,需要你自己調。)
2.3、在大數據集上訓練專家模型(Training ensembles of specialists on very big datasets )
- 當數據集非常巨大以及模型非常復雜時,訓練多個模型所需要的資源是難以想象的,因此作者提出了一種新的集成模型(ensemble)方法:
- 一個generalist model:使用全部數據訓練。
- 多個specialist model(專家模型):對某些容易混淆的類別進行訓練。
- specialist model的訓練集中,一半是由訓練集中包含某些特定類別的子集(special subset)組成,剩下一半是從剩余數據集中隨機選取的。
- 這個ensemble的方法中,只有generalist model是使用完整數據集訓練的,時間較長,而剩余的所有specialist model由于訓練數據相對較少,且相互獨立,可以并行訓練,因此訓練模型的總時間可以節約很多。
- specialist model由于只使用特定類別的數據進行訓練,因此模型對別的類別的判斷能力幾乎為0,導致非常容易過擬合。
- 解決辦法:當 specialist model 通過 hard targets 訓練完成后,再使用由 generalist model 生成的 soft targets 進行微調。這樣做是因為 soft targets 保留了一些對于其他類別數據的信息,因此模型可以在原來基礎上學到更多知識,有效避免了過擬合。
- 實現流程:
3、參考資料
總結
以上是生活随笔為你收集整理的论文笔记:Distilling the Knowledge的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 论文笔记:DeepID2
- 下一篇: 论文笔记:Git Loss