Pytorch的BatchNorm层使用中容易出现的问题
前言
本文主要介紹在pytorch中的Batch Normalization的使用以及在其中容易出現(xiàn)的各種小問(wèn)題,本來(lái)此文應(yīng)該歸屬于[1]中的,但是考慮到此文的篇幅可能會(huì)比較大,因此獨(dú)立成篇,希望能夠幫助到各位讀者。如有謬誤,請(qǐng)聯(lián)系指出,如需轉(zhuǎn)載,請(qǐng)注明出處,謝謝。
? \nabla ? 聯(lián)系方式:
e-mail: FesianXu@gmail.com
QQ: 973926198
github: https://github.com/FesianXu
知乎專欄: 計(jì)算機(jī)視覺/計(jì)算機(jī)圖形理論與應(yīng)用
微信公眾號(hào):
qrcode
Batch Normalization,批規(guī)范化
Batch Normalization(簡(jiǎn)稱為BN)[2],中文翻譯成批規(guī)范化,是在深度學(xué)習(xí)中普遍使用的一種技術(shù),通常用于解決多層神經(jīng)網(wǎng)絡(luò)中間層的協(xié)方差偏移(Internal Covariate Shift)問(wèn)題,類似于網(wǎng)絡(luò)輸入進(jìn)行零均值化和方差歸一化的操作,不過(guò)是在中間層的輸入中操作而已,具體原理不累述了,見[2-4]的描述即可。
在BN操作中,最重要的無(wú)非是這四個(gè)式子:
注意到這里的最后一步也稱之為仿射(affine),引入這一步的目的主要是設(shè)計(jì)一個(gè)通道,使得輸出output至少能夠回到輸入input的狀態(tài)(當(dāng) γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0時(shí))使得BN的引入至少不至于降低模型的表現(xiàn),這是深度網(wǎng)絡(luò)設(shè)計(jì)的一個(gè)套路。
整個(gè)過(guò)程見流程圖,BN在輸入后插入,BN的輸出作為規(guī)范后的結(jié)果輸入的后層網(wǎng)絡(luò)中。
好了,這里我們記住了,在BN中,一共有這四個(gè)參數(shù)我們要考慮的:
??? γ , β \gamma, \beta γ,β:分別是仿射中的 w e i g h t \mathrm{weight} weight和 b i a s \mathrm{bias} bias,在pytorch中用weight和bias表示。
??? μ B \mu_{\mathcal{B}} μB?和 σ B 2 \sigma_{\mathcal{B}}^2 σB2?:和上面的參數(shù)不同,這兩個(gè)是根據(jù)輸入的batch的統(tǒng)計(jì)特性計(jì)算的,嚴(yán)格來(lái)說(shuō)不算是“學(xué)習(xí)”到的參數(shù),不過(guò)對(duì)于整個(gè)計(jì)算是很重要的。在pytorch中,這兩個(gè)統(tǒng)計(jì)參數(shù),用running_mean和running_var表示[5],這里的running指的就是當(dāng)前的統(tǒng)計(jì)參數(shù)不一定只是由當(dāng)前輸入的batch決定,還可能和歷史輸入的batch有關(guān),詳情見以下的討論,特別是參數(shù)momentum那部分。
Update 2020/3/16:
因?yàn)锽N層的考核,在工作面試中實(shí)在是太常見了,在本文順帶補(bǔ)充下BN層的參數(shù)的具體shape大小。
以圖片輸入作為例子,在pytorch中即是nn.BatchNorm2d(),我們實(shí)際中的BN層一般是對(duì)于通道進(jìn)行的,舉個(gè)例子而言,我們現(xiàn)在的輸入特征(可以視為之前討論的batch中的其中一個(gè)樣本的shape)為 x ∈ R C × W × H \mathbf{x} \in \mathbb{R}^{C \times W \times H} x∈RC×W×H(其中C是通道數(shù),W是width,H是height),那么我們的 μ B ∈ R C \mu_{\mathcal{B}} \in \mathbb{R}^{C} μB?∈RC,而方差 σ B 2 ∈ R C \sigma^{2}_{\mathcal{B}} \in \mathbb{R}^C σB2?∈RC。而仿射中 w e i g h t , γ ∈ R C \mathrm{weight}, \gamma \in \mathbb{R}^{C} weight,γ∈RC以及 b i a s , β ∈ R C \mathrm{bias}, \beta \in \mathbb{R}^{C} bias,β∈RC。我們會(huì)發(fā)現(xiàn),這些參數(shù),無(wú)論是學(xué)習(xí)參數(shù)還是統(tǒng)計(jì)參數(shù)都會(huì)通道數(shù)有關(guān),其實(shí)在pytorch中,通道數(shù)的另一個(gè)稱呼是num_features,也即是特征數(shù)量,因?yàn)椴煌ǖ赖奶卣餍畔⑼ǔ:懿幌嗤?#xff0c;因此需要隔離開通道進(jìn)行處理。
有些朋友可能會(huì)認(rèn)為這里的weight應(yīng)該是一個(gè)張量,而不應(yīng)該是一個(gè)矢量,其實(shí)不是的,這里的weight其實(shí)應(yīng)該看成是 對(duì)輸入特征圖的每個(gè)通道得到的歸一化后的 x ^ \hat{\mathbf{x}} x^進(jìn)行尺度放縮的結(jié)果,因此對(duì)于一個(gè)通道數(shù)為 C C C的輸入特征圖,那么每個(gè)通道都需要一個(gè)尺度放縮因子,同理,bias也是對(duì)于每個(gè)通道而言的。這里切勿認(rèn)為 y i ← γ x ^ i + β y_i \leftarrow \gamma \hat{x}_i+\beta yi?←γx^i?+β這一步是一個(gè)全連接層,他其實(shí)只是一個(gè)尺度放縮而已。關(guān)于這些參數(shù)的形狀,其實(shí)可以直接從pytorch源代碼看出,這里截取了_NormBase層的部分初始代碼,便可一見端倪。
class _NormBase(Module):
??? """Common base of _InstanceNorm and _BatchNorm"""
??? _version = 2
??? __constants__ = ['track_running_stats', 'momentum', 'eps',
???????????????????? 'num_features', 'affine']
??? def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
???????????????? track_running_stats=True):
??????? super(_NormBase, self).__init__()
??????? self.num_features = num_features
??????? self.eps = eps
??????? self.momentum = momentum
??????? self.affine = affine
??????? self.track_running_stats = track_running_stats
??????? if self.affine:
??????????? self.weight = Parameter(torch.Tensor(num_features))
??????????? self.bias = Parameter(torch.Tensor(num_features))
??????? else:
??????????? self.register_parameter('weight', None)
??????????? self.register_parameter('bias', None)
??????? if self.track_running_stats:
??????????? self.register_buffer('running_mean', torch.zeros(num_features))
??????????? self.register_buffer('running_var', torch.ones(num_features))
??????????? self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
??????? else:
??????????? self.register_parameter('running_mean', None)
??????????? self.register_parameter('running_var', None)
??????????? self.register_parameter('num_batches_tracked', None)
??????? self.reset_parameters()
?
在Pytorch中使用
Pytorch中的BatchNorm的API主要有:
torch.nn.BatchNorm1d(num_features,
???????????????????? eps=1e-05,
???????????????????? momentum=0.1,
???????????????????? affine=True,
???????????????????? track_running_stats=True)
?
一般來(lái)說(shuō)pytorch中的模型都是繼承nn.Module類的,都有一個(gè)屬性trainning指定是否是訓(xùn)練狀態(tài),訓(xùn)練狀態(tài)與否將會(huì)影響到某些層的參數(shù)是否是固定的,比如BN層或者Dropout層。通常用model.train()指定當(dāng)前模型model為訓(xùn)練狀態(tài),model.eval()指定當(dāng)前模型為測(cè)試狀態(tài)。
同時(shí),BN的API中有幾個(gè)參數(shù)需要比較關(guān)心的,一個(gè)是affine指定是否需要仿射,還有個(gè)是track_running_stats指定是否跟蹤當(dāng)前batch的統(tǒng)計(jì)特性。容易出現(xiàn)問(wèn)題也正好是這三個(gè)參數(shù):trainning,affine,track_running_stats。
??? 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個(gè),如果affine=False,則 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0,并且不能學(xué)習(xí)被更新。一般都會(huì)設(shè)置成affine=True[10]
??? trainning和track_running_stats,track_running_stats=True表示跟蹤整個(gè)訓(xùn)練過(guò)程中的batch的統(tǒng)計(jì)特性,得到方差和均值,而不只是僅僅依賴與當(dāng)前輸入的batch的統(tǒng)計(jì)特性。相反的,如果track_running_stats=False那么就只是計(jì)算當(dāng)前輸入的batch的統(tǒng)計(jì)特性中的均值和方差了。當(dāng)在推理階段的時(shí)候,如果track_running_stats=False,此時(shí)如果batch_size比較小,那么其統(tǒng)計(jì)特性就會(huì)和全局統(tǒng)計(jì)特性有著較大偏差,可能導(dǎo)致糟糕的效果。
一般來(lái)說(shuō),trainning和track_running_stats有四種組合[7]
??? trainning=True, track_running_stats=True。這個(gè)是期望中的訓(xùn)練階段的設(shè)置,此時(shí)BN將會(huì)跟蹤整個(gè)訓(xùn)練過(guò)程中batch的統(tǒng)計(jì)特性。
??? trainning=True, track_running_stats=False。此時(shí)BN只會(huì)計(jì)算當(dāng)前輸入的訓(xùn)練batch的統(tǒng)計(jì)特性,可能沒法很好地描述全局的數(shù)據(jù)統(tǒng)計(jì)特性。
??? trainning=False, track_running_stats=True。這個(gè)是期望中的測(cè)試階段的設(shè)置,此時(shí)BN會(huì)用之前訓(xùn)練好的模型中的(假設(shè)已經(jīng)保存下了)running_mean和running_var并且不會(huì)對(duì)其進(jìn)行更新。一般來(lái)說(shuō),只需要設(shè)置model.eval()其中model中含有BN層,即可實(shí)現(xiàn)這個(gè)功能。[6,8]
??? trainning=False, track_running_stats=False 效果同(2),只不過(guò)是位于測(cè)試狀態(tài),這個(gè)一般不采用,這個(gè)只是用測(cè)試輸入的batch的統(tǒng)計(jì)特性,容易造成統(tǒng)計(jì)特性的偏移,導(dǎo)致糟糕效果。
同時(shí),我們要注意到,BN層中的running_mean和running_var的更新是在forward()操作中進(jìn)行的,而不是optimizer.step()中進(jìn)行的,因此如果處于訓(xùn)練狀態(tài),就算你不進(jìn)行手動(dòng)step(),BN的統(tǒng)計(jì)特性也會(huì)變化的。如
model.train() # 處于訓(xùn)練狀態(tài)
for data, label in self.dataloader:
?? ?pred = model(data) ?
?? ?# 在這里就會(huì)更新model中的BN的統(tǒng)計(jì)特性參數(shù),running_mean, running_var
?? ?loss = self.loss(pred, label)
?? ?# 就算不要下列三行代碼,BN的統(tǒng)計(jì)特性參數(shù)也會(huì)變化
?? ?opt.zero_grad()
?? ?loss.backward()
?? ?opt.step()
?
這個(gè)時(shí)候要將model.eval()轉(zhuǎn)到測(cè)試階段,才能固定住running_mean和running_var。有時(shí)候如果是先預(yù)訓(xùn)練模型然后加載模型,重新跑測(cè)試的時(shí)候結(jié)果不同,有一點(diǎn)性能上的損失,這個(gè)時(shí)候十有八九是trainning和track_running_stats設(shè)置的不對(duì),這里需要多注意。 [8]
假設(shè)一個(gè)場(chǎng)景,如下圖所示:
此時(shí)為了收斂容易控制,先預(yù)訓(xùn)練好模型model_A,并且model_A內(nèi)含有若干BN層,后續(xù)需要將model_A作為一個(gè)inference推理模型和model_B聯(lián)合訓(xùn)練,此時(shí)就希望model_A中的BN的統(tǒng)計(jì)特性值running_mean和running_var不會(huì)亂變化,因此就必須將model_A.eval()設(shè)置到測(cè)試模式,否則在trainning模式下,就算是不去更新該模型的參數(shù),其BN都會(huì)改變的,這個(gè)將會(huì)導(dǎo)致和預(yù)期不同的結(jié)果。
Update 2020/3/17:
評(píng)論區(qū)的Oshrin朋友提出問(wèn)題
??? 作者您好,寫的很好,但是是否存在問(wèn)題。即使將track_running_stats設(shè)置為False,如果momentum不為None的話,還是會(huì)用滑動(dòng)平均來(lái)計(jì)算running_mean和running_var的,而非是僅僅使用本batch的數(shù)據(jù)情況。而且關(guān)于凍結(jié)bn層,有一些更好的方法。
這里的momentum的作用,按照文檔,這個(gè)參數(shù)是在對(duì)統(tǒng)計(jì)參數(shù)進(jìn)行更新過(guò)程中,進(jìn)行指數(shù)平滑使用的,比如統(tǒng)計(jì)參數(shù)的更新策略將會(huì)變成:
其中的更新后的統(tǒng)計(jì)參數(shù) x ^ n e w \hat{x}_{\mathrm{new}} x^new?,是根據(jù)當(dāng)前觀察 x t x_t xt?和歷史觀察 x ^ \hat{x} x^進(jìn)行加權(quán)平均得到的(差分的加權(quán)平均相當(dāng)于歷史序列的指數(shù)平滑),默認(rèn)的momentum=0.1。然而跟蹤歷史信息并且更新的這個(gè)行為是基于track_running_stats為true并且training=true的情況同時(shí)成立的時(shí)候,才會(huì)進(jìn)行的,當(dāng)在track_running_stats=true, training=false時(shí)(在默認(rèn)的model.eval()情況下,即是之前談到的四種組合的第三個(gè),既滿足這種情況),將不涉及到統(tǒng)計(jì)參數(shù)的指數(shù)滑動(dòng)更新了。[12,13]
這里引用一個(gè)不錯(cuò)的BN層凍結(jié)的例子,如:[14]
import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *
def fix_bn(m):
??? classname = m.__class__.__name__
??? if classname.find('BatchNorm') != -1:
??????? m.eval()
model = models.resnet50(pretrained=True)
model.cuda()
model = network(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()
總結(jié)來(lái)說(shuō),在某些情況下,即便整體的模型處于model.train()的狀態(tài),但是某些BN層也可能需要按照需求設(shè)置為model_bn.eval()的狀態(tài)。
Update 2020.6.19:
評(píng)論區(qū)有個(gè)同學(xué)問(wèn)了一個(gè)問(wèn)題:
??? K.G.lee:想問(wèn)博主,為什么模型測(cè)試時(shí)的參數(shù)為trainning=False, track_running_stats=True啊??測(cè)試不是用訓(xùn)練時(shí)的滑動(dòng)平均值嗎?為什么track_running_stats=True呢?為啥要跟蹤當(dāng)前batch??
我感覺這個(gè)問(wèn)題問(wèn)得挺好的,我們需要去翻下源碼[15],我們發(fā)現(xiàn)我們所有的BatchNorm層都有個(gè)共同的父類_BatchNorm,我們最需要關(guān)注的是return F.batch_norm()這一段,我們發(fā)現(xiàn),其對(duì)training的判斷邏輯是
training=self.training or not self.track_running_stats
那么,其實(shí)其在eval階段,這里的track_running_stats并不能設(shè)置為False,原因很簡(jiǎn)單,這樣會(huì)使得上面談到的training=True,導(dǎo)致最終的期望程序錯(cuò)誤。至于設(shè)置了track_running_stats=True是不是會(huì)導(dǎo)致在eval階段跟蹤測(cè)試集的batch的統(tǒng)計(jì)參數(shù)呢?我覺得是不會(huì)的,我們追蹤會(huì)發(fā)現(xiàn)[16],整個(gè)流程的最后一步其實(shí)是調(diào)用了torch.batch_norm(),其是調(diào)用C++的底層函數(shù),其參數(shù)列表可和track_running_stats一點(diǎn)關(guān)系都沒有,只是由training控制,因此當(dāng)training=False時(shí),其不會(huì)跟蹤統(tǒng)計(jì)參數(shù)的,只是會(huì)調(diào)用訓(xùn)練集訓(xùn)練得到的統(tǒng)計(jì)參數(shù)。(當(dāng)然,時(shí)間有限,我也沒有繼續(xù)追到C++層次去看源碼了)。
class _BatchNorm(_NormBase):
??? def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
???????????????? track_running_stats=True):
??????? super(_BatchNorm, self).__init__(
??????????? num_features, eps, momentum, affine, track_running_stats)
??? def forward(self, input):
??????? self._check_input_dim(input)
??????? # exponential_average_factor is set to self.momentum
??????? # (when it is available) only so that it gets updated
??????? # in ONNX graph when this node is exported to ONNX.
??????? if self.momentum is None:
??????????? exponential_average_factor = 0.0
??????? else:
??????????? exponential_average_factor = self.momentum
??????? if self.training and self.track_running_stats:
??????????? # TODO: if statement only here to tell the jit to skip emitting this when it is None
??????????? if self.num_batches_tracked is not None:
??????????????? self.num_batches_tracked = self.num_batches_tracked + 1
??????????????? if self.momentum is None:? # use cumulative moving average
??????????????????? exponential_average_factor = 1.0 / float(self.num_batches_tracked)
??????????????? else:? # use exponential moving average
??????????????????? exponential_average_factor = self.momentum
??????? return F.batch_norm(
??????????? input, self.running_mean, self.running_var, self.weight, self.bias,
??????????? self.training or not self.track_running_stats,
??????????? exponential_average_factor, self.eps)
?? def batch_norm(input, running_mean, running_var, weight=None, bias=None,
?????????????? training=False, momentum=0.1, eps=1e-5):
??? # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor? # noqa
??? r"""Applies Batch Normalization for each channel across a batch of data.
??? See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
??? :class:`~torch.nn.BatchNorm3d` for details.
??? """
??? if not torch.jit.is_scripting():
??????? if type(input) is not Tensor and has_torch_function((input,)):
??????????? return handle_torch_function(
??????????????? batch_norm, (input,), input, running_mean, running_var, weight=weight,
??????????????? bias=bias, training=training, momentum=momentum, eps=eps)
??? if training:
??????? _verify_batch_size(input.size())
??? return torch.batch_norm(
??????? input, weight, bias, running_mean, running_var,
??????? training, momentum, eps, torch.backends.cudnn.enabled
??? )
??
Reference
[1]. 用pytorch踩過(guò)的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度學(xué)習(xí)優(yōu)化策略-1>Batch Normalization(BN)
[4]. 詳解深度學(xué)習(xí)中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的參數(shù)track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白話《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》
[12]. https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146/2
[13]. https://zhuanlan.zhihu.com/p/65439075
[14]. https://github.com/NVIDIA/apex/issues/122
[15]. https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
[16]. https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#batch_norm
————————————————
版權(quán)聲明:本文為CSDN博主「FesianXu」的原創(chuàng)文章,遵循CC 4.0 BY-SA版權(quán)協(xié)議,轉(zhuǎn)載請(qǐng)附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/LoseInVain/article/details/86476010
總結(jié)
以上是生活随笔為你收集整理的Pytorch的BatchNorm层使用中容易出现的问题的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Himall商城LinqHelper帮助
- 下一篇: linux pid t 头文件_linu