BN究竟起了什么作用?一个闭门造车的分析
作者丨蘇劍林
單位丨追一科技
研究方向丨NLP,神經網絡
個人主頁丨kexue.fm
BN,也就是 Batch Normalization [1],是當前深度學習模型(尤其是視覺相關模型)的一個相當重要的技巧,它能加速訓練,甚至有一定的抗過擬合作用,還允許我們用更大的學習率,總的來說頗多好處(前提是你跑得起較大的 batch size)。?
那BN究竟是怎么起作用呢?早期的解釋主要是基于概率分布的,大概意思是將每一層的輸入分布都歸一化到 N (0, 1) 上,減少了所謂的 Internal Covariate Shift,從而穩定乃至加速了訓練。這種解釋看上去沒什么毛病,但細思之下其實有問題的:不管哪一層的輸入都不可能嚴格滿足正態分布,從而單純地將均值方差標準化無法實現標準分布 N (0, 1) ;其次,就算能做到 N (0, 1) ,這種詮釋也無法進一步解釋其他歸一化手段(如 Instance Normalization、Layer Normalization)起作用的原因。?
在去年的論文 How Does Batch Normalization Help Optimization? [2] 里邊,作者明確地提出了上述質疑,否定了原來的一些觀點,并提出了自己關于 BN 的新理解:他們認為 BN 主要作用是使得整個損失函數的 landscape 更為平滑,從而使得我們可以更平穩地進行訓練。?
本文主要也是分享這篇論文的結論,但論述方法是筆者“閉門造車”地構思的。竊認為原論文的論述過于晦澀了,尤其是數學部分太不好理解,所以本文試圖盡可能直觀地表達同樣觀點。?
閱讀本文之前,請確保你已經清楚知道 BN 是什么,本文不再重復介紹 BN 的概念和流程。
一些基礎結論
在這部分內容中我們先給出一個核心的不等式,繼而推導梯度下降,并得到一些關于模型訓練的基本結論,為后面 BN 的分析鋪墊。?
核心不等式
假設函數 f(θ) 的梯度滿足 Lipschitz 約束( L 約束),即存在常數 L 使得下述恒成立:
那么我們有如下不等式:
證明并不難,定義輔助函數 f(θ+tΔθ), t∈[0,1],然后直接得到:
梯度下降
假設 f(θ) 是損失函數,而我們的目標是最小化 f(θ),那么這個不等式告訴我們很多信息。首先,既然是最小化,自然是希望每一步都在下降,即 f(θ+Δθ)<f(θ),而必然是非負的,所以要想下降的唯一選擇就是,這樣一個自然的選擇就是:
這里 η>0 是一個標量,即學習率。
可以發現,式 (4) 就是梯度下降的更新公式,所以這也就是關于梯度下降的一種推導了,而且這個推導過程所包含的信息量更為豐富,因為它是一個嚴格的不等式,所以它還可以告訴我們關于訓練的一些結論。?
Lipschitz約束
將梯度下降公式代入到不等式 (2) ,我們得到:
注意到,保證損失函數下降的一個充分條件是,為了做到這一點,要不就要 η 足夠小,要不就要 L 足夠小。但是 η 足夠小意味著學習速度會相當慢,所以更理想的情況是 L 能足夠小,降低了 L 就可以用更大的學習率了,能加快學習速度,這也是它的好處之一。
但 L 是 f(θ) 的內在屬性,因此只能通過調整 f 本身來降低 L。
BN是怎樣煉成的
本節將會表明:以降低神經網絡的梯度的 L 常數為目的,可以很自然地導出 BN。也就是說,BN 降低了神經網絡的梯度的 L 常數,從而使得神經網絡的學習更加容易,比如可以使用更大的學習率。而降低梯度的 L 常數,直觀來看就是讓損失函數沒那么“跌宕起伏”,也就是使得 landscape 更光滑的意思了。
注:我們之前就討論過 L 約束,之前我們討論的是神經網絡關于“輸入”滿足 L 約束,這導致了權重的譜正則和譜歸一化(請參考參數”滿足 L 約束,這導致了對輸入的各種歸一化手段,而 BN 是其中最自然的一種。
梯度分析
以監督學習為例,假設神經網絡表示為,損失函數取,那么我們要做的事情是:
也就是,所以:
順便說明一下,本文的每個記號均沒有加粗,但是根據實際情況不同它既有可能表示標量,也有可能表示向量。?
非線性假設
顯然, f(θ) 是一個非線性函數,它的非線性來源有兩個:
1. 損失函數一般是非線性的;
2. 神經網絡 h(x;θ) 中的激活函數是非線性的。
關于激活函數,當前主流的激活函數基本上都滿足一個特性:導數的絕對值不超過某個常數。我們現在來考慮這個特性能否推廣到損失函數中去,即(在整個訓練過程中)損失函數的梯度是否會被局限在某個范圍內?
看上去,這個假設通常都是不成立的,比如交叉熵是 ?log p,而它的導數是 ?1/p,顯然不可能被約束在某個有限范圍。但是,損失函數聯通最后一層的激活函數一起考慮時,則通常是滿足這個約束的。比如二分類是最后一層通常用 sigmoid 激活,這時候配合交叉熵就是:
這時候它關于 h 的梯度在 -1 到 1 之間。當然,確實存在一些情況是不成立的,比如回歸問題通常用 mse 做損失函數,并且最后一層通常不加激活函數,這時候它的梯度是一個線性函數,不會局限在一個有限范圍內。
這種情況下,我們只能寄望于模型有良好的初始化以及良好的優化器,使得在整個訓練過程中都比較穩定了。這個“寄望”看似比較強,但其實能訓練成功的神經網絡基本上都滿足這個“寄望”。
柯西不等式
我們的目的是探討滿足 L 約束的程度,并且探討降低這個 L 的方法。為此,我們先考慮最簡單的單層神經網絡(輸入向量,輸出標量) h(x;w,b)=g(?x,w?+b) ,這里的 g 是激活函數。這時候:
基于我們的假設,和都被閑置在某個范圍之內,所以可以看到偏置項 b 的梯度是很平穩的,它的更新也應當會是很平穩的。但是 w 的梯度不一樣,它跟輸入 x 直接相關。
關于 w 的梯度差,我們有:
將圓括號部分記為 λ(x,y;w,b,Δw),根據前面的討論,它被約束在某個范圍之內,這部分依然是平穩項,既然如此,我們不妨假設它天然滿足 L 約束,即:
這時候我們只需要關心好額外的 x。根據柯西不等式,我們有:
這樣一來,我們得到了與(當前層)參數無關的,如果我們希望降低 L 常數,最直接的方法是降低這一項。
減均值除標準差
要注意,雖然我們很希望降低梯度的 L 常數,但這是有前提的——必須在不會明顯降低原來神經網絡擬合能力的前提下,否則只需要簡單乘個 0 就可以讓 L 降低到 0 了,但這并沒有意義。?
式 (12) 的結果告訴我們,想辦法降低是個直接的做法,這意味著我們要對輸入 x 進行變換。然后根據剛才的“不降低擬合能力”的前提,最簡單并且可能有效的方法就是平移變換了,即我們考慮 x→x?μ,換言之,考慮適當的 μ 使得:
最小化。這只不過是一個二次函數的最小值問題,不難解得最優的 μ 是:
于是,我們得到:?
結論 1:將輸入減去所有樣本的均值,能降低梯度的 L 常數,是一個有利于優化又不降低神經網絡擬合能力的操作。
接著,我們考慮縮放變換,即,這里的 σ 是一個跟 x 大小一樣的向量,而除法則是逐位相除。這導致:
σ 是對 L 的一個最直接的縮放因子,但問題是縮放到哪里比較好?如果一味追求更小的 L,那直接 σ→∞ 就好了,但這樣的神經網絡已經完全沒有擬合能力了;但如果 σ 太小導致 L 過大,那又不利于優化。所以我們需要一個標準。
以什么為標準好呢?再次回去看梯度的表達式 (9),前面已經說了,偏置項的梯度不會被 x 明顯地影響,所以它似乎會是一個靠譜的標準。如果是這樣的話,那相當于將輸入 x 的這一項權重直接縮放為 1,那也就是說,變成了一個全 1 向量,再換言之:
這樣一來,一個相對自然的原則是將 σ 取為輸入的標準差。這時候,我們能感覺到除以標準差這一項,更像是一個自適應的學習率校正項,它一定程度上消除了不同層級的輸入對參數優化的差異性,使得整個網絡的優化更為“同步”,或者說使得神經網絡的每一層更為“平權”,從而更充分地利用好了整個神經網絡,減少了在某一層過擬合的可能性。當然,如果輸入的量級過大時,除以標準差這一項也有助于降低梯度的 L 常數。?
于是有結論:?
結論 2:將輸入(減去所有樣本的均值后)除以所有樣本的標準差,有類似自適應學習率的作用,使得每一層的更新更為同步,減少了在某一層過擬合的可能性,是一個提升神經網絡性能的操作。
推導窮,BN現
前面的推導,雖然表明上僅以單層神經網絡(輸入向量,輸出標量)為例子,但是結論已經有足夠的代表性了,因為多層神經網絡本質上也就是單層神經網絡的復合而已(關于這個論點,可以參考筆者舊作《從 Boosting 學習到神經網絡:看山是山?》[3] )。?
所以有了前面的兩個結論,那么 BN 基本就可以落實了:訓練的時候,每一層的輸出都減去均值除以標準差即可,不過由于每個 batch 的只是整體的近似,而期望 (14) , (16) 是全體樣本的均值和標準差,所以 BN 避免不了的是 batch size 大點效果才好,這對算力提出了要求。?
此外,我們還要維護一組變量,把訓練過程中的均值方差存起來,供預測時使用,這就是 BN 中通過滑動平均來統計的均值方差變量了。至于 BN 的標準設計中,減均值除標準差后還補充上的 β , γ 項,我認為僅是錦上添花作用,不是最必要的,所以也沒法多做解釋了。
簡單的總結
本文從優化角度分析了 BN 其作用的原理,所持的觀點跟 How Does Batch Normalization Help Optimization? 基本一致,但是所用的數學論證和描述方式個人認為會更簡單易懂寫。最終的結論是減去均值那一項,有助于降低神經網絡梯度的 L 常數,而除以標準差的那一項,更多的是起到類似自適應學習率的作用,使得每個參數的更新更加同步,而不至于對某一層、某個參數過擬合。?
當然,上述詮釋只是一些粗糙的引導,完整地解釋 BN 是一件很難的事情,BN 的作用更像是多種因素的復合結果,比如對于我們主流的激活函數來說, [?1,1] 基本上都是非線性較強的區間,所以將輸入弄成均值為 0、方差為 1,也能更充分地發揮激活函數的非線性能力,不至于過于浪費神經網絡的擬合能力。?
總之,神經網絡的理論分析都是很艱難的事情,遠不是筆者能勝任的,也就只能在這里寫寫博客,講講可有可無的故事來貽笑大方罷了。
相關鏈接
[1]?https://arxiv.org/abs/1502.03167[2]?https://arxiv.org/abs/1805.11604[3]?https://kexue.fm/archives/3873
點擊以下標題查看作者其他文章:?
基于DGCNN和概率圖的輕量級信息抽取模型
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
??來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
?
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點擊 |?閱讀原文?| 查看作者博客
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的BN究竟起了什么作用?一个闭门造车的分析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 为什么电脑桌面黑屏怎么办 电脑桌面怎么办
- 下一篇: 抢票 | AI未来说学术论坛第十期 视频