Batch Normalization深入理解
Batch Normalization深入理解
1. BN的提出背景是什么?
統計學習中的一個很重要的假設就是輸入的分布是相對穩定的。如果這個假設不滿足,則模型的收斂會很慢,甚至無法收斂。所以,對于一般的統計學習問題,在訓練前將數據進行歸一化或者白化(whitening)是一個很常用的trick。
但這個問題在深度神經網絡中變得更加難以解決。在神經網絡中,網絡是分層的,可以把每一層視為一個單獨的分類器,將一個網絡看成分類器的串聯。這就意味著,在訓練過程中,隨著某一層分類器的參數的改變,其輸出的分布也會改變,這就導致下一層的輸入的分布不穩定。分類器需要不斷適應新的分布,這就使得模型難以收斂。
對數據的預處理可以解決第一層的輸入分布問題,而對于隱藏層的問題無能為力,這個問題就是Internal Covariate Shift。而Batch Normalization其實主要就是在解決這個問題。
除此之外,一般的神經網絡的梯度大小往往會與參數的大小相關(仿射變換),且隨著訓練的過程,會產生較大的波動,這就導致學習率不宜設置的太大。Batch Normalization使得梯度大小相對固定,一定程度上允許我們使用更高的學習率。
(左)沒有任何歸一化,(右)應用了batch normalization
2. BN工作原理是什么?
假定我們的輸入是一個大小為 N 的mini-batch?,通過下面的四個式子計算得到的y? 就是Batch Normalization(BN)的值
數據看起來像高斯分布。
首先,由(2.1)和(2.2)得到mini-batch的均值和方差,之后進行(2.3)的歸一化操作,在分母加上一個小的常數是為了避免出現除0操作.整個過程中,只有最后的(2.4)引入了額外參數γ和β,他們的size都為特征長度,與 xi 相同。
BN層通常添加在隱藏層的激活函數之前,線性變換之后。如果我們把(2.4)和之后的激活函數放在一起看,可以將他們視為一層完整的神經網絡(線性+激活)。(注意BN的線性變換和一般隱藏層的線性變換仍有區別,前者是element-wise的,后者是矩陣乘法。)
此時,??可以視為這一層網絡的輸入,而??是擁有固定均值和方差的。這就解決了Covariate Shift.
另外,? y還具有保證數據表達能力的作用。?在normalization的過程中,不可避免的會改變自身的分布,而這會導致學習到的特征的表達能力有一定程度的丟失。通過引入參數γ和β,極端情況下,網絡可以將γ和β訓練為原分布的標準差和均值來恢復數據的原始分布。這樣保證了引入BN,不會使效果更差。
3. BN實現方法是什么?
我們將Batch Normalization分成正向(只包括訓練)和反向兩個過程。
正向過程的參數x是一個mini-batch的數據,gamma和beta是BN層的參數,bn_param是一個字典,包括??的取值和用于inference的?和?的移動平均值,最后返回BN層的輸出y,會在反向過程中用到的中間變量cache,以及更新后的移動平均。
反向過程的參數是來自上一層的誤差信號dout,以及正向過程中存儲的中間變量cache,最后返回???的偏導數。
實現與推導的不同在于,實現是對整個batch的操作。
import numpy as npdef batchnorm_forward(x, gamma, beta, bn_param):# read some useful parameterN, D = x.shapeeps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))# BN forward passsample_mean = x.mean(axis=0)sample_var = x.var(axis=0)x_ = (x - sample_mean) / np.sqrt(sample_var + eps)out = gamma * x_ + beta# update moving averagerunning_mean = momentum * running_mean + (1-momentum) * sample_meanrunning_var = momentum * running_var + (1-momentum) * sample_varbn_param['running_mean'] = running_meanbn_param['running_var'] = running_var# storage variables for backward passcache = (x_, gamma, x - sample_mean, sample_var + eps)return out, cachedef batchnorm_backward(dout, cache):# extract variablesN, D = dout.shapex_, gamma, x_minus_mean, var_plus_eps = cache# calculate gradientsdgamma = np.sum(x_ * dout, axis=0)dbeta = np.sum(dout, axis=0)dx_ = np.matmul(np.ones((N,1)), gamma.reshape((1, -1))) * doutdx = N * dx_ - np.sum(dx_, axis=0) - x_ * np.sum(dx_ * x_, axis=0)dx *= (1.0/N) / np.sqrt(var_plus_eps)return dx, dgamma, dbeta?
4. BN優點是什么?
-
更快的收斂。
-
降低初始權重的重要性。
-
魯棒的超參數。
-
需要較少的數據進行泛化。
5. BN缺點是什么?
- 在使用小batch size的時候不穩定: batch normalization必須計算平均值和方差,以便在batch中對之前的輸出進行歸一化。如果batch大小比較大的話,這種統計估計是比較準確的,而隨著batch大小的減少,估計的準確性持續減小。
以上是ResNet-50的驗證錯誤圖。可以推斷,如果batch大小保持為32,它的最終驗證誤差在23左右,并且隨著batch大小的減小,誤差會繼續減小(batch大小不能為1,因為它本身就是平均值)。損失有很大的不同(大約10%)。
如果batch大小是一個問題,為什么我們不使用更大的batch?我們不能在每種情況下都使用更大的batch。在finetune的時候,我們不能使用大的batch,以免過高的梯度對模型造成傷害。在分布式訓練的時候,大的batch最終將作為一組小batch分布在各個實例中。
-
導致訓練時間的增加:NVIDIA和卡耐基梅隆大學進行的實驗結果表明,“盡管Batch Normalization不是計算密集型,而且收斂所需的總迭代次數也減少了。”但是每個迭代的時間顯著增加了,而且還隨著batch大小的增加而進一步增加。
batch normalization消耗了總訓練時間的1/4。原因是batch normalization需要通過輸入數據進行兩次迭代,一次用于計算batch統計信息,另一次用于歸一化輸出。
-
訓練和推理時不一樣的結果:例如,在真實世界中做“物體檢測”。在訓練一個物體檢測器時,我們通常使用大batch(YOLOv4和Faster-RCNN都是在默認batch大小= 64的情況下訓練的)。但在投入生產后,這些模型的工作并不像訓練時那么好。這是因為它們接受的是大batch的訓練,而在實時情況下,它們的batch大小等于1,因為它必須一幀幀處理。考慮到這個限制,一些實現傾向于基于訓練集上使用預先計算的平均值和方差。另一種可能是基于你的測試集分布計算平均值和方差值。
-
對于在線學習不好:在線學習是一種學習技術,在這種技術中,系統通過依次向其提供數據實例來逐步接受訓練,可以是單獨的,也可以是通過稱為mini-batch的小組進行。每個學習步驟都是快速和便宜的,所以系統可以在新的數據到達時實時學習。
由于它依賴于外部數據源,數據可能單獨或批量到達。由于每次迭代中batch大小的變化,對輸入數據的尺度和偏移的泛化能力不好,最終影響了性能。
-
對于循環神經網絡不好:
雖然batch normalization可以顯著提高卷積神經網絡的訓練和泛化速度,但它們很難應用于遞歸結構。batch normalization可以應用于RNN堆棧之間,其中歸一化是“垂直”應用的,即每個RNN的輸出。但是它不能“水平地”應用,例如在時間步之間,因為它會因為重復的重新縮放而產生爆炸性的梯度而傷害到訓練。
[^注]: 一些研究實驗表明,batch normalization使得神經網絡容易出現對抗漏洞,但我們沒有放入這一點,因為缺乏研究和證據。
6. 可替換的方法:
在batch normalization無法很好工作的情況下,有幾種替代方法。
-
Layer Normalization
-
Instance Normalization
-
Group Normalization (+ weight standardization)
-
Synchronous Batch Normalization
7.參考資料
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Deriving the Gradient for the Backward Pass of Batch Normalization
CS231n Convolutional Neural Networks for Visual Recognition
https://towardsdatascience.com/curse-of-batch-normalization-8e6dd20bc304
總結
以上是生活随笔為你收集整理的Batch Normalization深入理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python、numpy,keras,t
- 下一篇: Linux绝对路径和相对路径简单介绍