【深度学习】重新思考BatchNorm中的 “Batch”
這篇很有趣且很有用,激動的趕緊把文章看了一遍,不愧是FAIR,實驗看的太爽了。
之前對于Norm的研究主要在于改變Norm的維度,然后衍生出了BatchNorm、GroupNorm、InstanceNorm和LayerNorm等方法,但是除了BN外的其他Norm含義是確定,而BN的batch卻可以有多種采樣方式,本文就是為了探討BN的batch使用不同的采樣方式會有什么影響,堪稱BatchNorm圣經,建議全文背誦(ps:GN也是吳育昕的作品)。
本文總共4大核心實驗,每個核心實驗有多個子結論。
01
Motivation
BatchNorm現在已經廣泛的應用于CNN中。但是BN針對不同的場景使用時有許多細微的差異,如果選擇不當會降低模型的性能。BatchNorm相對于其他算子來說,主要的不同在于BN是對batch數據進行操作的。BN在batch數據中進行統計量計算,而其他算子一般都是獨立處理單個樣本的。因此影響BN的輸出不僅僅取決于單個樣本的性質,還取決于batch的采樣方式。
如圖所示,左右各舉例了三種batch采樣方式。其中左圖三種batch采樣方式分別為entire dataset、mini-batches和subset of mini-batches,右圖三種batch采樣方式分別為entire domain、each domain和mixture of each domain。
本文實驗證明了使用BN時不考慮batch的采樣方式會在許多方面產生負面影響,合理使用batch采樣方式會改善模型性能。
02
A Review of BatchNorm
簡單回顧一下BN的計算形式,這里以CNN中的BN為例。假設BN的輸入feature維度為??,逐通道統計量mean和std為??,那么BN的輸出y為:
??
假設batch的大小為N,mini-batch X的維度為??,那么mini-batch 統計量為??,定義為:
??
推理的時候,??用訓練集計算得到的統計量,定義為??。
不同于默認的BN設置,因為batch采樣方式主要影響的是統計量mean和std,本文將mean和std看成是一個逐通道分開計算的仿射變換(可以等價為一個1x1的depth-wise layer)。
03
Whole?Population?as aBatch
BN中統計量的計算默認使用EMA方法,但是作者實驗發現EMA會導致模型性能次優,然后提出了PreciseBN方法,近似將整個訓練集統計量作為一個batch。
Inaccuracy of EMA
EMA是指數滑動平均的縮寫,為了統計??,EMA在訓練過程中對統計量進行更新:
??
EMA方法導致次優解的原因有兩點:
1.當??太大時,統計量收斂速度變慢。
2.當??太小時,最近幾個mini-batches影響更大,統計量無法表示整個訓練集的統計量。
Towards Precise Population Statistics
為了得到整個訓練集更加精確的統計量,PreciseBN采用了兩點小技巧:
1.將相同模型用于多個mini-batches來收集batch統計量
2.將多個batch收集的統計量聚合成一個population統計量
比如有N個樣本需要通過數量為的Bmini-batch進行PreciseBN統計量計算,那么需要計算??次,統計量聚合公式為:
??
??
相比于EMA,PreciseBN有兩點重要的屬性:
1.PreciseBN的統計量是通過相同模型計算得到的,而EMA是通過多個歷史模型計算得到的。
2.PreciseBN的所有樣本的權重是相同的,而EMA不同樣本的權重是不同的。
100 samples of batch mean意思是相同epoch下模型對100個隨機batch統計量的結果。如圖所示,在訓練早期EMA的統計量不精確,會導致最終模型性能次優。由于滑動平均的計算方式導致EMA的統計量滯后于PrciseBN。
4個主要結論:
1.推理時使用PreciseBN會更加穩定。
2.大batch訓練對EMA影響更大。
3.PreciseBN只需要10^3~10^4個樣本可以得到近似最優。
4.小batch會產生統計量積累錯誤。
04
Batch in Training and Testing
BN在訓練和測試中行為不一致:訓練時,BN的統計量來自mini-batch;測試時,BN的統計量來自population。這部分主要探討了BN行為不一致對模型性能的影響,并且提出消除不一致的方法提升模型性能。
Effect of Normalization Batch Size
為了避免混淆,將SGD batch size或者total batch size定義為所有GPU上總的batch size大小,將normalization batch ? size定義為單個GPU上的batch size大小。
normalization batch size對training noise和train-test inconsistency有著直接影響:使用更大的batch,mini-batch統計量越接近population統計量,從而降低training noise和train-test inconsistency。
以下實驗的SGD batch size固定使用1024大小。
為了便于分析,作者觀察了3種不同評估方法的錯誤率:
1.在訓練集上對mini-batch統計量進行評估
2.在驗證集上對mini-batch統計量進行評估
3.在驗證集上對population統計量進行評估
Training noise:當normalization batch size非常小時,單個樣本會受到同一個min-batch樣本的嚴重影響,導致訓練精度較差,優化困難。
Generalization gap:隨著normalization batch size的增加,mini-batch的驗證集和訓練集的之間的泛化誤差會增大,這可能是由于training noise和train-test inconsistency沒有正則化。
Train-test inconsistency:在小batch下,mini-batch統計量和population統計量的不一致是影響性能的主要因素。當normalization batch size增大時,細微的不一致可以提供正則化效果減少驗證誤差。在mini-batch為32~128之間時,正則化達到平衡,模型性能最優。
為了保持train和test的BN統計量一致,作者提出了兩種方法來解決不一致問題,一種是推理的時候使用mini-batch統計量,另一種是訓練的時候使用population batch統計量。
Use Mini-batch in Inference
作者在Mask R-CNN上進行實驗,mini-batch的結果超過了population的結果,證明了在推理中使用mini-batch可以有效的緩解訓練測試不一致。(ps:不使用norm效果略差,使用GN效果更好)
Use Population Batch in Training
為了在訓練階段使用population統計量,作者采用FrozenBN的方法,FrozenBN使用population統計量。具體地,作者先選擇第80個epoch模型,然后將所有BN替換成FrozenBN,然后訓練20個epoch。
FrozenBN可以有效緩解訓練測試不一致,即使在小normalization batch size,也能達到比較好的性能。但是隨著normalization batch size增大,作者提出的兩種緩解不一致的方法都不如常規BN的結果。
05
Batch from Different Domains
BN的訓練過程可以看成是兩個獨立的階段:第一個階段是通過SGD學習features,第二個階段是由這些features得到population統計量。兩個階段分別稱為SGD training和population statistics training。
由于BN多了一個population統計階段,導致訓練和測試之間的domain shift。當數據來自多個doman時,SGD training、population statistics training和testing三個步驟的domain gap都會對泛化性造成影響。
實驗主要探究了兩種使用場景:第一種,模型在一個domain上進行訓練,然后在其他domain上進行測試;第二種,模型在多個domain上進行訓練。
Domain to Compute Population Statistics
作者實驗發現,當存在顯著的domain shift時,模型使用評估domain的population統計量會得到更好的結果,可以緩解訓練測試的不一致。
BatchNorm in Multi-Domain Training
為了對多個domain的情況進行實驗,作者將RetinaNet head中的BN統計量進行實驗設計。RetinaNet的head是5個feature層共享的,這意味著會接收來自5個不同分布或者domain的輸入進行訓練。
左圖的訓練形式非常簡單,head獨立作用于不同的feature層,都有自己獨立的統計量。右圖將所有輸入特征flatten然后concat在一起,統一進行統計量計算。兩種不同計算統計量的方式稱為domain-specific statistics和shared statistics。
最終實驗表明,SGD training、population statistics training和testing保持一致是非常重要的,并且全部使用domain-specific能取得最好的效果。(ps:不使用norm效果略差,使用GN效果更好)
06
Information Leakage within a?Batch
BN在使用中還存在一種information leakage現象,因為BN是對mini-batch的樣本計算統計量的,導致在樣本進行獨立預測時,會利用mini-batch內其他樣本的統計信息。
Exploit Patterns in Mini-batches
作者實驗發現,當使用random采樣的mini-batch統計量時,驗證誤差會增加,當使用population統計量時,驗證誤差會隨著epoch的增加逐漸增大,驗證了BN信息泄露問題的存在。
為了處理信息泄露問題,之前常見的作法是使用SyncBN,來弱化mini-batch內樣本之間的相關性。另一種解決方法是在進入head之前在GPU之間隨機打亂RoI features,這給每個GPU分配了一個隨機的樣本子集來進行歸一化,同時也削弱了min-batch樣本之間的相關性,如上圖所示。
實驗結果表明,shuffling和SyncBN都能有效地處理信息泄漏,使得head在測試時能夠很好地泛化。在速度方面,我們注意到shuffling需要更少的跨gpu同步,但是shuffling每次傳輸的數據比SyncBN多。因此,shuffling和SyncBN的相對效率跟具體模型架構相關。
Cheating in Contrastive Learning
在對比學習和度量學習時,訓練目標通常是在mini-batch下進行比較的,這種情況下BN也會造成信息泄露,導致模型在訓練期間作弊,之前的研究提出了很多不同方法來針對性解決對比學習和度量學習的信息泄露問題。
07
總結
本文從多個角度探討了BN的batch使用不同的采樣方式會有什么影響,并且做了非常詳盡的對比試驗,堪稱BatchNorm圣經,建議全文背誦。
另外,看完后最大的感觸是,BN不會用就別用,GN yyds。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯溫州大學《機器學習課程》視頻 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【深度学习】重新思考BatchNorm中的 “Batch”的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 温州大学《深度学习》课程课件(九、目标检
- 下一篇: 电脑公司win11旗舰版32位镜像v20