训练集样本不平衡问题对CNN的影响
轉載自??訓練集樣本不平衡問題對CNN的影響
訓練集樣本不平衡問題對CNN的影響
本文首發于知乎專欄“ai insight”!
卷積神經網絡(CNN)可以說是目前處理圖像最有力的工具了。
而在機器學習分類問題中,樣本不平衡又是一個經常遇到的問題。最近在使用CNN進行圖片分類時,發現CNN對訓練集樣本不平衡問題很敏感。
在網上搜索了一下,發現http://www.diva-portal.org/smash/get/diva2:811111/FULLTEXT01.pdf這篇文章對這個問題已經做了比較細致的探索。于是就把它簡單整理了一下,相關的記錄如下。
?
1、實驗數據與使用的網絡
所謂樣本不平衡,就是指在分類問題中,每一類對應的樣本的個數不同,而且差別較大。
這樣的不平衡的樣本往往使機器學習算法的表現變得比較差。那么在CNN中又有什么樣的影響呢?作者選用了CIFAR-10作為數據源來生成不平衡的樣本數據。
CIFAR-10是一個簡單的圖像分類數據集。共有10類(airplane,automobile,bird,cat,deer,dog, frog,horse,ship,truck),每一類含有5000張訓練圖片,1000張測試圖片。
CIFAR-10樣例如圖:
訓練時,選擇的網絡是這里的CIFAR-10訓練網絡和參數(來自Alex Krizhevsky)。這個網絡含有3個卷積層,還有10個輸出結點。
之所以不選用效果更好的CNN網絡,是因為我們的目的是在實驗時訓練很多次進行比較,而不是獲得多么好的性能。
而這個CNN網絡因為比較淺,訓練速度比較快,比較符合我們的要求。
?
2、類別不平衡數據的生成
直接從原始CIFAR-10采樣,通過控制每一類采樣的個數,就可以產生類別不平衡的訓練數據。如下表所示:
這里的每一行就表示“一份”訓練數據。而每個數字就表示這個類別占這“一份”訓練數據的百分比。
Dist. 1:類別平衡,每一類都占用10%的數據。
Dist. 2、Dist. 3:一部分類別的數據比另一部分多。
Dist. 4、Dist 5:只有一類數據比較多。
Dist. 6、Dist 7:只有一類數據比較少。
Dist. 8: 數據個數呈線性分布。
Dist. 9:數據個數呈指數級分布。
Dist. 10、Dist. 11:交通工具對應的類別中的樣本數都比動物的多
對每一份訓練數據都進行訓練,測試時用的測試集還是每類1000個的原始測試集,保持不變。
?
3、類別不平衡數據的訓練結果
以上數據經過訓練后,每一類對應的預測正確率如下:
第一列Total表示總的正確率,下面是每一類分別的正確率。
從實驗結果中可以看出:
類別完全平衡時,結果最好。
類別“越不平衡”,效果越差。比如Dist. 3就比Dist. 2更不平衡,效果就更差。同樣的對比還有Dist. 4和Dist. 5,Dist. 8和Dist. 9。其中Dist. 5和Dist. 9更是完全訓練失敗了。
?
4、過采樣訓練的結果
作者還實驗了“過采樣”(oversampling)這種平衡數據集的方法。
這里的過采樣方法是:對每一份數據集中比較少的類,直接復制其中的圖片增大樣本數量直至所有類別平衡。
再次訓練,進行測試,結果為:
可以發現過采樣的效果非常好,基本與平衡時候的表現一樣了。
過采樣前后效果對比,可以發現過采樣效果非常好:
?
5、總結
CNN確實對訓練樣本中類別不平衡的問題很敏感。
平衡的類別往往能獲得最佳的表現,而不平衡的類別往往使模型的效果下降。如果訓練樣本不平衡,可以使用過采樣平衡樣本之后再訓練。
這確實是一個“經驗主義”的結論,但多少給我們平常訓練CNN模型帶來一些啟發和幫助。
總結
以上是生活随笔為你收集整理的训练集样本不平衡问题对CNN的影响的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 中国第一个自然保护区是 是哪个
- 下一篇: 适合圆脸男生的发型设计