元学习Meta learning深入理解
目錄
基本理解
元學習與傳統的機器學習不同在哪里?
基本思想
MAML
MAML與pre-training有什么區別呢?
1. 損失函數不同
?2. 優化思想不同
MAML的優點及特點
MAML工作機理
?MAML應用:Toy Example
Reptile
基本理解
Meta Learning,翻譯為元學習,也可以認為是learn to learn。
元學習與傳統的機器學習不同在哪里?
知乎博主“南有喬木”在理解 元學習與傳統的機器學習 這里舉了個通俗易懂的例子,拿來給大家分享:
把訓練算法類比成學生在學校的學習,傳統的機器學習任務對應的是在每個科目上分別訓練一個模型,而元學習是提高學生整體的學習能力,學會學習。
學校中 ,有的學生各科成績都好,有的學生卻存在偏科現象。
- 各科成績都好,說明學生“元學習”能力強,學會了如何學習,可以迅速適應不同科目的學習任務。
- 偏科學生“元學習”能力相對較弱,只能某一科學習成績好,換門科就不行了。不會舉一反三,觸類旁通。
現在經常使用的深度神經網絡都是“偏科生”,分類和回歸對應的網絡模型完全不同,即使同樣是分類任務,把人臉識別的網絡架構用在分類ImageNet數據上,也未必能達到很高的準確率。
?
還有一個不同點:
- 傳統的深度學習方法都是從頭開始學習(訓練),即learning from scratch,對算力和時間都是更大的消耗和考驗。
- 元學習強調從不同的若干小任務小樣本來學習一個對未知樣本未知類別都有好的判別和泛化能力的模型
基本思想
寫在前面:圖片均來自李宏毅老師教學視頻
圖1對圖1的解釋:
Meta learning又稱為learn to learn,是說讓機器“學會學習”,擁有學習的能力。
元學習的訓練樣本和測試樣本都是基于任務的。通過不同類型的任務訓練模型,更新模型參數,掌握學習技巧,然后舉一反三,更好地學習其他的任務。比如任務1是語音識別,任務2是 圖像識別,···,任務100是文本分類,任務101與 前面100個任務類型均不同,訓練任務即為這100個不相同的任務,測試任務為第101個任務。
圖2對圖2的解釋:
在機器學習中,訓練樣本中的訓練集稱為train?set,測試集稱為test set。元學習廣泛應用于小樣本學習中,在元學習中,訓練樣本中的訓練集稱為support set,訓練樣本中的測試集叫做query set。
注意 :在機器學習中,只有一個大樣本數據集,將這個一個大數據集分成了兩部分,稱為train set和test set;
但是在元學習中,不止一個數據集,有多少個不同的任務,就有多少個數據集,然后每個數據集又分成兩部分,分別稱為support set和query set。
這里沒有考慮驗證集。
圖3?對圖3的解釋:
?
圖3為傳統深度學習的操作方式,即:
元學習與傳統深度學習的聯系在哪里?
圖3中紅色方框中的東西都是人為設計定義的,即我們常說的“超參數”,而元學習的目標就是去自動學習或者說代替方框中的東西,不同的代替方式就發明出不同的元學習算法。
圖4對圖4的解釋:
圖4簡單介紹了元學習的原理。
在神經網絡算法,都需定義一個損失函數來評價模型好壞,元學習的損失通過N個任務的測試損失相加得到。定義在第n個任務上的測試損失是?,則對于N個任務來說,總的損失為?,這就是元學習的優化目標。
假設有兩個任務,Task1和Task2,通過訓練任務1,得到任務1的損失函數l1,通過訓練任務2,得到任務2的損失函數l2,然后將這兩個任務的損失函數相加,得到整個訓練任務的損失函數,即圖4右上角的公式。
?
如果前文對元學習了解還不夠,后面有更詳細的解釋:
Meta Learning 的算法有很多,有些高大上的算法可以針對不同的訓練任務,輸出不同的神經網絡結構和超參數,例如 Neural Architecture Search (NAS) 和 AutoML。這些算法大多都相當復雜,我們普通人難以實現。另外一種比較容易實現的 Meta Learning 算法,就是本文要介紹的 MAML 和 Reptile,它們不改變深度神經網絡的結構,只改變網絡的初始化參數。
?
MAML
理解MAML算法的損失函數含義和推導過程,首先得與pre-training區分開來。
對圖5的解釋:
我們定義初始化參數為,其初始化參數為,定義在第n個測試任務上訓練之后的模型參數為,于是MAML總的損失函數為?。
圖5MAML與pre-training有什么區別呢?
1. 損失函數不同
MAML的損失函數為?。
pre-training的損失函數是。
直觀上理解是MAML所評測的損失是在任務訓練之后的測試loss,而pre-training是直接在原有基礎上求損失沒有經過訓練。如圖6所示。
圖6??????2. 優化思想不同
這里先分享一下我看到的對損失函數最恰當的描述:(https://zhuanlan.zhihu.com/p/72920138)
損失函數的奧妙:初始化參數掌控全場,分任務參數各自為營
圖7?
圖8?如圖7和圖8所示:
上圖中橫坐標代表網絡參數,縱坐標代表損失函數。淺綠和墨綠兩條曲線代表兩個 task 的損失函數隨參數變化曲線。
假設模型參數的和向量都是一維的,
model pre-training的初衷是尋找一個從一開始就讓所有任務的損失之和處于最小狀態,它并不保證所有任務都能訓練到最好的,如上圖所示,??即收斂到局部最優。從圖7中看就是,loss值按照計算公式達到了最小值,但此時task2(淺綠)線只能收斂到左邊的綠點處,即局部最小處,而從整體看來,全局最小處在的右邊出現。
而MAML的初衷是找到一個不偏不倚的,使得不管是在任務1的loss曲線還是任務2的loss曲線上,都能下降到分別的全局最優。從圖8中看就是,loss值按照計算公式到達了最小值??,此時,task1可以收斂到左邊綠點處,task2可以收斂到右邊綠點處,二者均為全局最小值。
李宏毅老師在這里舉了個很生動的比喻:他把MAML比作選擇讀博,即更在意的是學生的以后的發展潛力;而model pre-training就相當于選擇畢業直接去大廠工作,馬上就把所學技能兌現金錢,在意的是當下表現如何。如圖9所示。
圖9MAML的優點及特點
如圖10所示:MAML
MAML工作機理
?在介紹MAML的論文中,給出的算法如圖11所示:
圖11?下面給出每步的詳細解釋:參考(https://zhuanlan.zhihu.com/p/57864886)
- Require1:task的分布,即隨機抽取若干個task組成任務池
- Require2:step size是學習率,MAML基于二重梯度,每次迭代包括兩次參數更新的過程,所以有兩個學習率可以調整。
有一個對MAML過程更直觀的圖:
圖12對圖12的解釋為:
?MAML應用:Toy Example
該 toy example 的目標是擬合正弦曲線:??,其中 a、b 都是隨機數,每一組 a、b 對應一條正弦曲線,從該正弦曲線采樣 K 個點,用它們的橫縱坐標作為一組 task,橫坐標為神經網絡的輸入,縱坐標為神經網絡的輸出。
我們希望通過在很多 task 上的學習,學到一組神經網絡的初始化參數,再輸入測試 task 的 K 個點時,經過快速學習,神經網絡能夠擬合測試 task 對應的正弦曲線。
圖13左側是用常規的 fine-tune 算法初始化神經網絡參數。我們觀察發現,當把所有訓練 task 的損失函數之和作為總損失函數,來直接更新網絡參數,會導致無論測試 task 輸入什么坐標,預測的曲線始終是 0 附近的曲線,因為 a 和 b 可以任意設置,所以所有可能的正弦函數加起來,它們的期望值為 0,因此為了獲得所有訓練 task 損失函數之和的 global minima,不論什么輸入坐標,神經網絡都將輸出 0。
右側是通過 MAML 訓練的網絡,MAML的初始化結果是綠色的線,和橘黃色的線有差異。但是隨著finetuning的進行,結果與橘黃色的線更加接近。
?
針對前面介紹的MAML,提出一個問題:
在更新訓練任務的網絡時,只走了一步,然后更新meta網絡。為什么是一步,可以是多步嗎?
李宏毅老師的課程中提到:
-
只更新一次,速度比較快;因為meta learning中,子任務有很多,都更新很多次,訓練時間比較久。
-
MAML希望得到的初始化參數在新的任務中finetuning的時候效果好。如果只更新一次,就可以在新任務上獲取很好的表現。把這件事情當成目標,可以使得meta網絡參數訓練是很好(目標與需求一致)。
-
當初始化參數應用到具體的任務中時,也可以finetuning很多次。
-
Few-shot learning往往數據較少。
Reptile
Reptile與MAML類似,其算法圖如下:
圖14Reptile 中,每更新一次??,需要 sample 一個 batch 的 task(圖中 batchsize=1),并在各個 task 上施加多次梯度下降,得到各個 task 對應的??。然后計算??和主任務的參數的差向量,作為更新??的方向。這樣反復迭代,最終得到全局的初始化參數。
?其偽代碼如下:
Reptile,每次sample出1個訓練任務?
Reptile,每次sample出1個batch訓練任務?
在Reptile中:
-
訓練任務的網絡可以更新多次
-
reptile不再像MAML一樣計算梯度(因此帶來了工程性能的提升),而是直接用一個參數??乘以meta網絡與訓練任務的網絡參數的差來更新meta網絡參數
-
從效果上來看,Reptile效果與MAML基本持平
?
以上為對元學習的深入理解,后續可能出MAML數學公式推導,感興趣的讀者留言~
參考資料
【1】https://zhuanlan.zhihu.com/p/72920138
【2】https://zhuanlan.zhihu.com/p/57864886
【3】https://zhuanlan.zhihu.com/p/108503451
【4】MAML論文https://arxiv.org/pdf/1703.03400.pdf
【5】?https://zhuanlan.zhihu.com/p/136975128
總結
以上是生活随笔為你收集整理的元学习Meta learning深入理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: LSTM详解
- 下一篇: 《疯狂JAVA讲义》笔记1