3.2)深度学习笔记:机器学习策略(2)
目錄
1)Carrying out error analysis
2)Cleaning up Incorrectly labeled data
3)Build your first system quickly then iterate
4)Training and testing on different distributios
5)Bias and Variance with mismatched data distributions
6)Addressing data mismatch
7)Transfer learning
8)Multi-task learning
9)What is end-to-end deep learning
10)Whether to use end-to-end deep learning
以下筆記是吳恩達老師深度學習課程第三門課第二周的的學習筆記:ML Strategy。筆記參考了黃海廣博士的內容,在此表示感謝。
1)Carrying out error analysis
通過人工檢查機器學習模型得出的結果中出現的一些錯誤,有助于深入了解下一步要進行的工作。這個過程被稱作錯誤分析(Error Analysis)。
例如,你可能會發現一個貓圖片識別器錯誤地將一些看上去像貓的狗誤識別為貓。這時,立即盲目地去研究一個能夠精確識別出狗的算法不一定是最好的選擇,因為我們不知道這樣做會對提高分類器的準確率有多大的幫助。
這時,我們可以從分類錯誤的樣本中統計出狗的樣本數量。根據狗樣本所占的比重來判斷這一問題的重要性。假如狗類樣本所占比重僅為 5%,那么即使花費幾個月的時間來提升模型對狗的識別率,改進后的模型錯誤率并沒有顯著改善;而如果錯誤樣本中狗類所占比重為 50%,那么改進后的模型性能會有較大的提升。因此,花費更多的時間去研究能夠精確識別出狗的算法是值得的。
這種人工檢查看似簡單而愚笨,但卻是十分必要的,因為這項工作能夠有效避免花費大量的時間與精力去做一些對提高模型性能收效甚微的工作,讓我們專注于解決影響模型準確率的主要問題。
在對輸出結果中分類錯誤的樣本進行人工分析時,可以建立一個表格來記錄每一個分類錯誤的具體信息,例如某些圖像是模糊的,或者是把狗識別成了貓等,并統計屬于不同錯誤類型的錯誤數量。這樣,分類結果會更加清晰。
總結一下,進行錯誤分析時,你應該觀察錯誤標記的例子,看看假陽性和假陰性,統計屬于不同錯誤類型的錯誤數量。在這個過程中,你可能會得到啟發,歸納出新的錯誤類型。總之,通過統計不同錯誤標記類型占總數的百分比,有助于發現哪些問題亟待解決,或者提供構思新優化方向的靈感。
2)Cleaning up Incorrectly labeled data
在監督式學習中,訓練樣本有時候會出現輸出Y標注錯誤的情況,即incorrectly labeled examples。如果這些label標錯的情況是隨機性的,DL算法對其包容性是比較強的,即健壯性好,一般可以直接忽略,無需修復。然而,如果是系統錯誤(systematic errors),這將對DL算法造成影響,降低模型性能。
剛才說的是訓練樣本中出現incorrectly labeled data,如果是dev/test sets中出現incorrectly labeled data,該怎么辦呢?
方法很簡單,利用上節內容介紹的error analysis,統計dev sets中所有分類錯誤的樣本中incorrectly labeled data所占的比例。根據該比例的大小,決定是否需要修正所有incorrectly labeled data,還是可以忽略。舉例說明,若:
- ??? Overall dev set error: 10%
- ??? Errors due incorrect labels: 0.6%
- ??? Errors due to other causes: 9.4%
上面數據表明Errors due incorrect labels所占的比例僅為0.6%,占dev set error的6%,而其它類型錯誤占dev set error的94%,即錯誤標標注標簽所占比例較低,這種情況下,可以忽略incorrectly labeled data。
如果優化DL算法后,出現下面這種情況:
- ??? Overall dev set error: 2%
- ??? Errors due incorrect labels: 0.6%
- ??? Errors due to other causes: 1.4%
上面數據表明Errors due incorrect labels所占的比例依然為0.6%,但是卻占dev set error的30%,而其它類型錯誤占dev set error的70%。因此,這種情況下,incorrectly labeled data不可忽略,需要手動修正。
我們知道,dev set的主要作用是在不同算法之間進行比較,選擇錯誤率最小的算法模型。但是,如果有incorrectly labeled data的存在,當不同算法錯誤率比較接近的時候,我們無法僅僅根據Overall dev set error準確指出哪個算法模型更好,必須修正incorrectly labeled data。
關于修正incorrect dev/test set data,有幾條建議:
- ??? Apply same process to your dev and test sets to make sure they continue to come from the same distribution
- ??? Consider examining examples your algorithm got right as well as ones it got wrong
- ??? Train and dev/test data may now come from slightly different distributions
3)Build your first system quickly then iterate
對于每個可以改善模型的合理方向,如何選擇一個方向集中精力處理成了問題。如果想搭建一個全新的機器學習系統,建議根據以下步驟快速搭建好第一個系統,然后開始迭代:
-
設置好訓練、驗證、測試集及衡量指標,確定目標;
-
快速訓練出一個初步的系統,用訓練集來擬合參數,用驗證集調參,用測試集評估;
-
通過偏差/方差分析以及錯誤分析等方法,決定下一步優先處理的方向。
4)Training and testing on different distributios
有時,我們很難得到來自同一個分布的訓練集和驗證/測試集。還是以貓識別作為例子,我們的訓練集可能由網絡爬蟲得到,圖片比較清晰,而且規模較大(例如 20 萬);而驗證/測試集可能來自用戶手機拍攝,圖片比較模糊,且數量較少(例如 1 萬),難以滿足作為訓練集時的規模需要。
雖然驗證/測試集的質量不高,但是機器學習模型最終主要應用于識別這些用戶上傳的模糊圖片。考慮到這一點,在劃分數據集時,可以將 20 萬張網絡爬取的圖片和 5000 張用戶上傳的圖片作為訓練集,而將剩下的 5000 張圖片一半作驗證集,一半作測試集。比起混合數據集所有樣本再隨機劃分,這種分配方法雖然使訓練集分布和驗證/測試集的分布并不一樣,但是能保證驗證/測試集更接近實際應用場景,在長期能帶來更好的系統性能。
5)Bias and Variance with mismatched data distributions
之前的學習中,我們通過比較人類水平誤差、訓練集錯誤率、驗證集錯誤率的相對差值來判斷進行偏差/方差分析。但在訓練集和驗證/測試集分布不一致的情況下,無法根據相對差值來進行偏差/方差分析。這是因為訓練集錯誤率和驗證集錯誤率的差值可能來自于算法本身(歸為方差),也可能來自于樣本分布不同,和模型關系不大。
在可能存在訓練集和驗證/測試集分布不一致的情況下,為了解決這個問題,我們可以再定義一個訓練-驗證集(Training-dev Set)。訓練-驗證集和訓練集的分布相同(或者是訓練集分割出的子集),但是不參與訓練過程。
現在,我們有了訓練集錯誤率、訓練-驗證集錯誤率,以及驗證集錯誤率。其中,訓練集錯誤率和訓練-驗證集錯誤率的差值反映了方差;而訓練-驗證集錯誤率和驗證集錯誤率的差值反映了樣本分布不一致的問題,從而說明模型擅長處理的數據和我們關心的數據來自不同的分布,我們稱之為數據不匹配(Data Mismatch)問題。
人類水平誤差、訓練集錯誤率、訓練-驗證集錯誤率、驗證集錯誤率、測試集錯誤率之間的差值所反映的問題如下圖所示:
一般情況下,human-level error、training error、training-dev error、dev error以及test error的數值是遞增的,但是也會出現dev error和test error下降的情況。這主要可能是因為訓練樣本比驗證/測試樣本更加復雜,難以訓練。
6)Addressing data mismatch
這里有兩條關于如何解決數據不匹配問題的建議:
- 做誤差分析,嘗試了解訓練集和驗證/測試集的具體差異(主要是人工查看訓練集和驗證集的樣本);
- 嘗試將訓練數據調整得更像驗證集,或者收集更多類似于驗證/測試集的數據。
如果你打算將訓練數據調整得更像驗證集,可以使用的一種技術是人工合成數據。我們以語音識別問題為例,實際應用場合(驗證/測試集)是包含背景噪聲的,而作為訓練樣本的音頻很可能是清晰而沒有背景噪聲的。為了讓訓練集與驗證/測試集分布一致,我們可以給訓練集人工添加背景噪聲,合成類似實際場景的聲音。
人工合成數據能夠使數據集匹配,從而提升模型的效果。但需要注意的是,不能給每段語音都增加同一段背景噪聲,因為這樣模型會對這段背景噪音出現過擬合現象,使得效果不佳。
7)Transfer learning
遷移學習(Tranfer Learning)是通過將已訓練好的神經網絡模型的一部分網絡結構應用到另一模型,將一個神經網絡從某個任務中學到的知識和經驗運用到另一個任務中,以顯著提高學習任務的性能。
例如,我們將為貓識別器構建的神經網絡遷移應用到放射科診斷中。因為貓識別器的神經網絡已經學習到了有關圖像的結構和性質等方面的知識,所以只要先刪除神經網絡中原有的輸出層,加入新的輸出層并隨機初始化權重系數(,),隨后用新的訓練集進行訓練,就完成了以上的遷移學習。
如果新的數據集很小,可能只需要重新訓練輸出層前的最后一層的權重,并保持其他參數不變;而如果有足夠多的數據,可以只保留網絡結構,重新訓練神經網絡中所有層的系數。這時初始權重由之前的模型訓練得到,這個過程稱為預訓練(Pre-Training),之后的權重更新過程稱為微調(Fine-Tuning)。
你也可以不止加入一個新的輸出層,而是多向神經網絡加幾個新層。
在下述場合進行遷移學習是有意義的:
-
兩個任務有同樣的輸入(比如都是圖像或者都是音頻);
-
擁有更多數據的任務遷移到數據較少的任務;
-
某一任務的低層次特征(底層神經網絡的某些功能)對另一個任務的學習有幫助。
8)Multi-task learning
遷移學習中的步驟是串行的;而多任務學習(Multi-Task Learning)使用單個神經網絡模型,利用共享表示采用并行訓練同時學習多個任務。多任務學習的基本假設是多個任務之間具有相關性,并且任務之間可以利用相關性相互促進。例如,屬性分類中,抹口紅和戴耳環有一定的相關性,單獨訓練的時候是無法利用這些信息,多任務學習則可以利用任務相關性聯合提高多個屬性分類的精度。
以汽車自動駕駛為例,需要實現的多任務是識別行人、車輛、交通標志和信號燈。如果在輸入的圖像中檢測出車輛和交通標志,則輸出的 y 為:
多任務學習模型的成本函數為:
其中,j 代表任務下標,有 c 個任務。對應的損失函數為:? ? ??
多任務學習是使用單個神經網絡模型來實現多個任務。實際上,也可以分別構建多個神經網絡來實現。多任務學習中可能存在訓練樣本 Y 某些標簽空白的情況,這不會影響多任務學習模型的訓練。
多任務學習和 Softmax 回歸看上去有些類似,容易混淆。它們的區別是,Softmax 回歸的輸出向量 y 中只有一個元素為 1;而多任務學習的輸出向量 y 中可以有多個元素為 1。
在下述場合進行多任務學習是有意義的:
-
訓練的一組任務可以共用低層次特征;
-
通常,每個任務的數據量接近;
-
能夠訓練一個足夠大的神經網絡,以同時做好所有的工作。多任務學習會降低性能的唯一情況(即和為每個任務訓練單個神經網絡相比性能更低的情況)是神經網絡還不夠大。
在多任務深度網絡中,低層次信息的共享有助于減少計算量,同時共享表示層可以使得幾個有共性的任務更好的結合相關性信息,任務特定層則可以單獨建模任務特定的信息,實現共享信息和任務特定信息的統一。
在實踐中,多任務學習的使用頻率要遠低于遷移學習。計算機視覺領域中的物體識別是一個多任務學習的例子。
9)What is end-to-end deep learning
在傳統的機器學習分塊模型中,每一個模塊處理一種輸入,然后其輸出作為下一個模塊的輸入,構成一條流水線。而端到端深度學習(End-to-end Deep Learning)只用一個單一的神經網絡模型來實現所有的功能。它將所有模塊混合在一起,只關心輸入和輸出。
如果數據量較少,傳統機器學習分塊模型所構成的流水線效果會很不錯。但如果訓練樣本足夠大,并且訓練出的神經網絡模型足夠復雜,那么端到端深度學習模型的性能會比傳統機器學習分塊模型更好。
而如果數據集規模適中,還是可以使用流水線方法,但是可以混合端到端深度學習,通過神經網絡繞過某些模塊,直接輸出某些特征。
10)Whether to use end-to-end deep learning
應用端到端學習的優點:
-
只要有足夠多的數據,剩下的全部交給一個足夠大的神經網絡。比起傳統的機器學習分塊模型,可能更能捕獲數據中的任何統計信息,而不需要用人類固有的認知(或者說,成見)來進行分析;
-
所需手工設計的組件更少,簡化設計工作流程;
缺點:
-
需要大量的數據;
-
排除了可能有用的人工設計組件;
根據以上分析,決定一個問題是否應用端到端學習的關鍵點是:是否有足夠的數據,支持能夠直接學習從 x 映射到 y 并且足夠復雜的函數?
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的3.2)深度学习笔记:机器学习策略(2)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Linux系统查看开放的端口、开启指定端
- 下一篇: 风险承受能力测试是什么?风险承受能力测试