抓住训练集中真正有用的样本,提升模型整体性能!
文 | Severus
編 | 小戲
在任務中尋找到真正有用的訓練樣本,可以說一直是機器學習研究者們共同的訴求。畢竟,找到了真正有用的訓練樣本,排除掉訓練樣本中的雜質,無論最終是提升訓練模型的效率,還是提升了模型最終的測試性能,其意義都是非凡的。因此,相似的研究早在我們還要做特征工程的時期就已經層出不窮。
而到了 DNN 時代,在做任務的我們不需要人工特征工程了,DNN 模型直接用表示學習把“特征”安排的明明白白,數據就成了黑盒。不過,DNN 模型雖不可解釋,但 DNN 模型的結果一定反映了數據的現象,所以充分利用DNN模型訓練過程中的中間結果,也是可以得到有效的數據上的反映的,所謂“原湯化原食”的確是行之有效的思路。
今天要介紹的兩篇工作,則是以上述思路出發,從兩個不同的角度去提升模型的性能。[1]通過模型的中間結果,尋找出訓練集中真正重要的樣本,給模型訓練,從而做到刪減數據集之后,也能得到很好的測試精度;[2]通過反復訓練模型表現很差的那一部分樣本,從而提升模型的整體測試效果。
開局少一半數據,咱也依然能贏!
論文題目:
Deep Learning on a Data Diet: Finding Important Examples Early in Training
論文鏈接:
https://arxiv.org/abs/2107.07075
2018 年,Toneva et al.[3]從“遺忘”的角度去研究了數據的重要性。文中定義了“遺忘事件”,即在訓練中某一個時刻,更新參數前原本預測正確的樣本在更新參數后預測錯誤了,即認為發生了一次遺忘。作者據此定義了樣本的“遺忘分數”,用于量化樣本是否容易被遺忘。
由此,作者發現,一些很少被遺忘的樣本對最終測試精度的影響也很小,反倒是容易被遺忘的那些樣本會影響最終的評測效果。而通過這種方式,我們自然也能夠通過遺忘分數去刪減數據集,即留下那些容易被遺忘的數據,去掉那些不容易被遺忘的數據。
而由于這個方法需要在訓練中收集到遺忘的統計數據,最終的遺忘分數往往需要在訓練中后期計算完成。文章在 CIFAR-10 數據集上訓練了 200 個 epoch,在第 25 個 epoch 的時候開始得到比較好的遺忘分數,第 75 個 epoch 開始遺忘分數趨于穩定。
本文作者希望,在訓練早期,就可以確認數據的重要性,這樣既可以大幅度減少模型訓練時間和計算資源消耗,也可以對DNN模型的訓練過程,及數據起到的作用等提供重要的見解。
同樣,本文也想要找到訓練集中“重要”的數據,這里對“重要”的定義是:訓練樣本對 Loss 減少的貢獻,也就是說,在訓練過程中,利用這個樣本優化模型參數之后,其他樣本計算得到的 Loss 減少的量。這個定義非常直觀反映了這條樣本的泛化能力,通過擬合這一條樣本,模型能夠從中得到多少幫助其擬合其他樣本的信息。
那么,很直觀的想法就是,直接求取一條樣本計算得到的梯度的范數。由于現在 DNN 模型都是用梯度下降方法更新參數的,那么這個值可以直接反映出該條樣本對模型參數權重的影響程度,這個影響程度我們就可以看作它對模型擬合其他樣本的影響程度了。
樣本重要程度的定義
在訓練的 時刻,樣本 的重要程度(GraNd)為:
其中,,也就是該時刻,樣本的 Loss 的梯度。
下面我們從數學角度論證一下:
在 時刻,Minibatch 中的樣本 計算得到 Loss 的導數為:
根據鏈式法則,則:
而 是 時刻權重的變化,則有
而由于模型權重是由梯度下降更新的,則有:
從而,
那么實際上,我們需要理解,當從 中刪除一條訓練樣本時,會怎樣影響權重的變化?
設,對于所有樣本 ,存在一個常數 ,使得:
證明:根據上面的式子,導出,代入,則令 ,結果成立。
當然這個式子在推導過程中是有不嚴謹的地方,例如代入等式之后,公因數是不能提取的,所以 值實際有問題,但不等式成立,這部分在撰寫時尊重原作者。
訓練樣本的貢獻由上式限定下來,由于常數 不受具體樣本 影響,則只需要看樣本的 Loss 的梯度的范數即可,也就是 GraNd 分數。(3)式表明,GraNd 分數較小的樣本對模型區分其余樣本的的影響是有限的,那么就可以根據訓練樣本 GraNd 分數的排名,去裁剪樣本,越高的分數表明樣本對的影響越大。
對于任意輸入 ,設,表示第 k 個 Logit 的梯度,根據鏈式法則,則 GraNd 分數可以寫成如下形式:
當使用交叉熵loss時,有
當與 Logits 之間大體正交,且與 Logits 和訓練樣本 之間有相似的大小時,則 GraNd 可以通過“錯誤向量”的范數近似計算。
此處定義訓練樣本的 EL2N 分數(即錯誤向量)為。
而實際上,作者也將本文給出的兩種計算樣本重要度的分數 GraNd 和 EL2N 與[3]的遺忘分數進行了比較,分析得出遺忘分數較高的樣本,GraNd 分數也會較高,這樣看來,二者所選擇的重要樣本其實也是類似的。
實驗效果
在確定了計算重要程度的方法之后,作者直接在訓練早期,分別計算遺忘分數、GraNd 及 EL2N ,然后利用計算的結果刪減了數據集,之后訓練模型,測試結果如下:
數據集和模型如上。其中,CIFAR10 保留了 50% 的數據,CINIC10 保留了 60% 的數據,CIFAR100 保留了75%的數據??梢钥吹?#xff0c;基本驗證了作者在前文中的猜想:訓練到中后期,通過三種計算方法裁剪數據的表現是各有優劣的,而 GraNd 和 EL2N 的確可以在訓練早期就得到不錯的結果。而且按上述比例裁剪了數據集之后,相比于使用所有的數據,測試精度損失的不是很大。
同時,作者也對比了分別使用 200 個 epoch 得到的遺忘分數,以及 20 個 epoch 得到的 GraNd 和 EL2N 計算樣本重要性,以不同的比例裁剪數據后的測試結果:
3個結果也分別是 CIFAR10 + ResNet18、CINIC10 + ResNet18 和 CIFAR100 + ResNet50。可以看到,首先相比于隨機裁剪,的確三種裁剪方式都展現了相當的能力,甚至在裁剪數據比較少的時候,利用GraNd和遺忘分數裁剪后的數據訓練,測試精度還超過了使用整個數據集訓練,這里我猜測,在裁剪比例比較少的時候,被裁剪掉的數據主要是離群點,所以測試精度相比于全數據訓練會稍高。
至此,作者提出的主要貢獻,即在訓練早期即可得到不錯的樣本重要度評估,以及利用它裁剪訓練數據之后,依然能保持不錯的測試精度都得到了驗證,而在論文中,作者也展示了使用樣本重要度可以做到其他的什么事情,以及利用一些補充實驗從多種角度分析了兩種計算重要程度的方法的性能,這里就不再贅述了,感興趣的讀者可以閱讀原文。
所以無論是計算遺忘分數的方法,還是本文提出的 GraNd 和 EL2N,實際在固定任務場景之下,即固定分布、固定范圍內是相當有價值的。
模型總出錯怎么辦?反復教它,直到它會
當我們訓練好一個模型之后,在測試過程中,我們會發現,總是有一些“疑難雜癥”一樣的樣本,怎么樣訓練都無法訓練正確,而實際上,我們也知道,這些樣本可能是訓練樣本中比較邊緣的部分(假設訓練集和測試集符合獨立同分布假設,即所有測試樣本均處于訓練集的分布之中,如超出了訓練集分布,則怎么也解決不了)。訓練的過程則是模型不斷擬合訓練樣本分布的過程,那么這種邊緣的東西,則會成為模型的疑難雜癥。
雖然機器學習研究中一直假設訓練樣本的分布就是真實數據的分布,可是我們也不得不承認,抽樣空間和真實的空間就是存在分布上的偏差,怎么樣都存在,這些“疑難雜癥”的存在正是表明了訓練集的分布和真實數據的分布存在的 Gap ,那么自然也就有了一個研究方向:在已有訓練集上,找到擬合的分布最接近于真實數據分布的參數,即分布魯棒性優化(Distributionally Robust Optimization, DRO),其基本思路是在訓練過程中按照分布將訓練樣本分成若干組,最小化最差的組的 Loss,從而去提升模型的效果。
而本文作者提到,DRO 方法雖然是可行的,但是它要對訓練樣本分組,這個成本還是略大的,能不能不去對訓練樣本分組,而是找到驗證集中那些比較差的樣本,反反復復教給模型,從而讓模型的效果更好呢?
問題定義
對于一個分類問題,輸入為 ,類別標簽 ,集合中有 n 個訓練樣本 ,目標是訓練得到模型。
在預定義好的組 之間評估模型的性能,每個訓練樣本 都屬于組 ,分類器的最壞組錯誤的定義如下:
其中,。
而訓練樣本中想得到這樣的組成本還是比較大的,但是在測試期間,使用少量的 m 個驗證集及在驗證集上預定義的若干個組,得到較好的最差 case 集合,用于調整超參,優化模型。
而驗證集的分組則是使用樣本中本身存在的一些屬性 與類別標簽的關聯來劃分的,即 ,如下圖中例子,分類水生鳥類和陸生鳥類,觀察數據發現,圖片的背景和標簽存在相關關系,則分為4類:
JTT:訓練兩次就好了
本文給出的方法則是兩階段的方法:首先,我們都知道,統計模型更傾向于去學習簡單的關聯(例如在水上的水生鳥類,在陸地上的陸生鳥類),而復雜的關聯(例如在水上的陸生鳥類,在陸地上的水生鳥類)學習的就比較差了,那么第一階段,直接使用訓練集訓練一個識別模型,直接找到當前模型的“易錯題集合”,即:
之后,則是增大“易錯題集合”中樣本的 Loss 權重,加強記憶,繼續訓練模型:
其中,是一個超參數。方法非常直觀,就是將易錯組加強記憶一遍,最終得到一個不錯的模型。
那么我們看一下最終的訓練結果,作者在圖像兩個圖像分類任務和兩個 NLP 任務上分別嘗試了效果,可以看到,在對比中情況較差的組的效果的確改善很多:
相比于要對整個訓練集分組的 DRO 方法,這個方法的確成本上小了很多,且相比于其他類似的方法(論文中有簡單介紹它所對比的幾種方法),它的提升也相對比較高,可以說是比較符合直覺,且效果比較好的方法。這個方法與分組時所定義的屬性(即)非常相關,例如在水生鳥類和陸生鳥類分類中,使用了圖片的背景,在照片男女性別分類中,使用了頭發顏色;在 NLI 任務中,使用了文本中是否含有否定詞語;在侮辱性評論分類任務中,使用了文本中是否含有性別描述詞。
可以看出,雖然不需要使用模型去計算分組了,但也需要人為地根據數據分布來對原本數據進行歸組,而如果歸組出現問題,則我想對最終的效果影響也不會小。而且,模型去過度關注預測錯誤的樣本,實際上對已經學到的正確的樣本似乎也會造成一定的損失,上文中可以看到,相比于一般方法,4 種改善錯誤的方法在整體的精度上都有了一定的損失,而想得到均衡的效果,在劃分集合上和超參選擇上都有很多的講究。
而且,會不會所謂最差的集合中,實際上是存在部分錯誤,或者離群點的呢?過度去擬合它,是否造成了過擬合,或者引入了噪聲呢?我們不得而知。
當然,文章中仍然有大量的對比分析及消融實驗,本文也不再贅述。
這篇工作實際上是部分利用了人的先驗知識,用更偏向直覺的方法,使用更簡單的算法去解決分布魯棒性優化(DRO)問題,其所關注也是模型的泛化能力。其基本動因就是,模型在某些樣本上的效果非常差,則說明現在所擬合的分布是有偏的,那么就讓模型的分布偏移,去包含那些相對“離群”的樣本,但由于盤子也只有那么大,偏向了離群的樣本,則也會舍去另一個邊緣的樣本。從最終結果上來看,雖然人為劃分的最差集合上效果變好了,但整體上變差了,實際上個人認為也沒有達到 DRO 想要達到的理想狀態(實際上我們可以看到,發表于 ICLR2020 的 Group DRO的整體效果看上去也好得多)。
固定任務之下,似乎我們也只能使用這種權衡的方式來糾偏,而如果我們面向的是海量數據,則我們也會有更多的選擇。
萌屋作者:Severus
Severus,在某廠工作的老程序員,主要從事自然語言理解方向,資深死宅,日常憤青,對個人覺得難以理解的同行工作都采取直接吐槽的態度。筆名取自哈利波特系列的斯內普教授,覺得自己也像他那么自閉、刻薄、陰陽怪氣,也向往他為愛而偉大。
作品推薦
深度學習,路在何方?
數據還是模型?人類知識在深度學習里還有用武之地嗎?
在錯誤的數據上,刷到 SOTA 又有什么意義?
后臺回復關鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
后臺回復關鍵詞【頂會】
獲取ACL、CIKM等各大頂會論文集!
?
[1] Paul M, Ganguli S, Dziugaite G K. Deep Learning on a Data Diet: Finding Important Examples Early in Training[J]. arXiv preprint arXiv:2107.07075, 2021.
[2] Liu E Z, Haghgoo B, Chen A S, et al. Just Train Twice: Improving Group Robustness without Training Group Information[C]//International Conference on Machine Learning. PMLR, 2021: 6781-6792.
[3] Toneva M, Sordoni A, Combes R T, et al. An empirical study of example forgetting during deep neural network learning[J]. arXiv preprint arXiv:1812.05159, 2018.
總結
以上是生活随笔為你收集整理的抓住训练集中真正有用的样本,提升模型整体性能!的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: positional encoding位
- 下一篇: 无内鬼,来点ICML/ACL审稿人笑话