如何跳出魔改模型?
?PaperWeekly 原創 · 作者|燕皖
單位|淵亭科技
研究方向|計算機視覺、CNN
剛剛和小伙伴參加完 kaggle 的 Global Wheat Detection 比賽獲得了 Private Leaderboard 第七的名次,首先,在這次比賽中我們發現在 Public Leaderboard 所得到成績和 Private Leaderboard 所得到的成績有很大的差異,其次,我們還發現了一些除魔改模型之外對漲點有效的方法。這是我們成績排名截圖。下面就具體看看這兩種方法。
Data argument
在訓練神經網絡時,我們常常會遇到的一個只有小幾百數據,然而,神經網絡模型都需要至少成千上萬的圖片數據。因此,為了獲得更多的數據,我們只要對現有的數據集進行微小的改變。
比如翻轉(flips)、平移(translations)、旋轉(rotations)等等。而我們要介紹的是 MixMatch,可以看做是半監督學習下的 mixup 擴增。
論文標題:MixMatch: A Holistic Approach to Semi-Supervised Learning
論文鏈接:https://arxiv.org/pdf/1905.02249.pdf
代碼鏈接:https://github.com/google-research/mixmatch
對于許多半監督學習方法,往往都是增加了一個損失項,這個損失項是在未標記的數據上計算的,以促進模型更好地泛化到訓練集之外的數據中。一般地,這個損失項可分為三類:
熵最小化——它鼓勵模型對未標記的數據輸出有信心的預測;
一致性正則化——當模型的輸入受到擾動時,它鼓勵模型產生相同的輸出分布;
泛型正則化——這有助于模型很好地泛化,避免對訓練數據的過度擬合。
MixMatch 整合了前面提到的一些 ideas 。對于給定一個已經標簽的 batch X 和同樣大小未標簽的 batch U,先生成一批經過 Mixup 處理的增強標簽數據 X' 和一批偽標簽的 U',然后分別計算帶標簽數據和未標簽數據的損失項。具體地流程如下:
將有標簽數據 X 和無標簽數據U混合在一起形成一個混合數據 W,然后有標簽數據 X 和 W 中的前 X 個進行 mixup 后,得到的數據作為有標簽數據 X'?,同樣,無標簽數據和 W 中的后 U個進行 mixup 后,得到的數據作為無標簽數據 U'。
損失函數:對于有標簽的數據,使用交叉熵;“guess”標簽的數據使用 MSE;然后將兩者加權組合。如下:
MixMatch 就是將無監督和有監督的數據分開進行 mixup 增強,然后無監督的 loss 使用的是 MSE。在比賽中,我們發現如果有監督和無監督一起進行 mixup,性能會下降,而分開進行 mixup 增強,則會進一步提升。
Semi-Supervised Learning
盡管 SSL 取得顯著進展,但 SSL 方法主要應用于圖像分類,今天介紹一種用于目標檢測的 SSL,稱為 STAC。
論文標題:A Simple Semi-Supervised Learning Framework for Object Detection
論文鏈接:https://arxiv.org/pdf/2005.04757.pdf
代碼鏈接:https://github.com/google-research/ssl_detection/
這篇文章利用了 Self-training和 Augmentation driven Consistency regularization,所以稱為 STAC。具體訓練步驟如下:
在可用的標簽圖像上訓練教師模型。
生成未標記圖像的偽標簽(即邊界框和他們的類別標簽)。
將?strong data augmentations?應用于未標記的圖像,并進行偽標簽的轉換。
計算無監督損失和有監督損失以訓練檢測器。
現在就看 SSL 的另一個關鍵點——未標記數據的無監督的損失函數:
其中,ls 是有監督的損失函數,lu 是無監督的損失函數,A 是應用于未標記圖像的強數據增強,p 和 s 是類別,t 和 q 是邊框坐標。
將 data augmentations 應用于半監督學習的方法在很早就有文獻提出,其背后的思想是 Consistency Regularization,即使對未標記的示例進行了增強,分類器也應該輸出相同的類分布。
具體地,一致性正則化強制未標記的樣本 x 應該與增強后的樣本 Augment(x) 保持一致,其中 Augment 是一個隨機數據增強函數,例如:隨機空間平移或添加噪聲。而本文實驗發現 λu ∈ [1, 2] 的時候效果最好。說明了半監督和有監督的重要性是不一樣的。
Global Wheat Detection
這里還是先介紹一下小麥頭檢測的比賽的內容:
比賽鏈接:
https://www.kaggle.com/c/global-wheat-detection/overview/code-requirements
比賽背景:主要是準確估計算出不同品種的小麥頭的密度和大小,從而幫助農民評估自己的農作物
比賽要求:檢測并框出圖片中的小麥頭,評估方式是 MAP,MAP,主要是權衡 precision 和 recall 的一個指標。截止時間 8 月 4 號,提交要求不能聯網并且 CPU Notebook <= 9 hours run-time,GPU Notebook <= 6 hours run-time
數據集:訓練集為 3434 張小麥圖片,在 Public Leaderboard 上計算成績的測試集占總的測試集的 62%,而在最終計算 Private Leaderboard 成績的測試集為占中的測試集的 38%。
3.1 賽題的難度
小麥外觀會因成熟度,顏色,基因型和頭部方向而異,因此對模型的泛化能力要求比較高。
一張圖片中小麥頭數量很多密度很大,因此常常出現小麥頭重疊的情況。
訓練數據少,并且部分圖片模糊和大小不一致。
3.2 解決思路
在訓練階段通過對圖像進行增強來數據擴增訓練出多個模型,然后在測試集上進行半監督學習。最后,在檢測時利用 TTA(Test time augmentation)增加檢測的準確性,并利用 wbf 融合多個模型的結果。
3.2.1 訓練數據擴增
由于在訓練的數據量較小(容易過擬合),而并且測試集的分布比較分散,對模型泛化能力要求比較高,因此采取對圖像增強的方式對訓練集進行擴增,采取的圖像擴增的方式有圖像的縮放、隨機水平翻轉和垂直翻轉、多個圖像的拼接、色彩空間 hsv 增強,通過這個方式訓練集擴增了 5 倍以此緩解訓練數據量小的問題。
3.2.2 半監督訓練
偽標簽對成績的提升有很大的幫助,最初在 Public Leaderboard 上沒加入偽標簽技術成績:0.7522 , 加入偽標簽技術后成績為:0.7720 ?,增加了 0.0198,排名提升了一百多名效果可以說是相當的明顯了。
具體地,我們對圖像的增強策略包括 Vertical Flip,HorizontalFlip,Rotate90,180,270,Multi-Scale 0.83 and 1.2 ,cutout,mixup,然后利用在訓練集訓練好的模型對未標記的測試集圖片進行偽標簽制作。
最開始,我們也僅僅是增加這些 argument,能夠達到?0.7720,進一步使用 MixMatch 和 STAC 的方法后,分別能夠達到 0.7734 和 0.7751。
_ | Acc | ||
Baseline | 0.7593 | ||
Rotate90,180,270 | 0.7640 | ||
Vertical/Horizontal Flip | 0.7682 | ||
Multi-Scale 0.83 and 1.2 | 0.7720 | ||
MixMatch | 0.7734 | ||
STAC(λu=1.4) | 0.7751 |
3.3.3 模型檢測過程
在檢測的過程中使用了 TTA(Test time augmentation),對原始圖像進行旋轉(90°,180°,270°)、垂直水平翻折、圖像縮放(放大 1.2 倍,縮小 0.87 倍),然后對 TTA 后的圖像進行檢測,最終將所得到的 box 進行 nms。
采用 TTA(測試時增強),可以對一幅小麥圖像做多種變換,創造出多個不同版本,對多個版本數據進行計算最后得到平均輸出作為最終結果,提高了結果的穩定性和精準度。
3.3 Private Leaderboard
在這次比賽中最終提交的兩個方案中,方案一也就是上面使用的方案取得了 Private Leaderboard 第七的成績,方案二:增加了根據驗證集計算成績自動選擇最好的閾值,對于偽標簽的訓練 epoch 增加到 15,而減少了半監督訓練中的 Argument(只剩下了旋轉)。
方案一在 Public Leaderboard 表現一般的方案成績為 0.7721 排在 55 名,但是卻在 Private Leaderboard 排在了第七名。方案二在 Public Leaderboard 上成績還不錯的方案 0.7751 在排名在 23,但是在 Private Leaderboard上37% 的測試集我的成績卻為 0.6954 排在了 300 多名。
寫在最后
由于本次比賽的數據集較小,很容易導致過擬合的現象。比賽結束的時候發現 Public leaderboard 成績還不錯,但是當 Private Leaderboard 出來后排名一落千丈,相比較而言,數據量大了的比賽絕大部分人排名都沒有變化,少數有 1~2 名的浮動在。
在這次比賽里的方案二由于 Public Leaderboard 上測試集占? 62%,測試集樣本較多,因此增加偽標簽的訓練使得它在 Public Leaderboard 上的成績增加很多,但是方案二發生了過擬合使得在 Private Leaderboard 上的成績下降就很明顯。
因此,深度學習網絡訓練到什么時候停止?在關注訓練集數量、質量以及分布等等因素的同時,更應該測試集(實際場景)的情況。否則常常會出現悲慘結局。另外,除了魔改模型,數據增強和半監督都是跳出魔改模型的好方法,能夠使得模型獲得更多的泛化能力。
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
- 上一篇: 怎样剪辑儿子在部队过生日
- 下一篇: 消息称 AMD 成台积电亚利桑那工厂新客