Repoptimizer论文理解与代码分析
上一篇介紹了RepVGG,RepVGG存在量化問題,Repopt通過將先驗(yàn)融入優(yōu)化器中,統(tǒng)一訓(xùn)練與測試模型解決了其量化不友好的問題。
論文鏈接: Re-parameterizing Your Optimizers rather than Architectures
Introduction
Repopt提出將模型結(jié)構(gòu)的先驗(yàn)信息直接用于修改梯度數(shù)值,其稱為梯度重參數(shù)化,對應(yīng)的優(yōu)化器稱為RepOptimizer。Repopt著重關(guān)注VGG式的直筒模型,訓(xùn)練得到RepOptVGG模型與VGG結(jié)構(gòu)一致,有著高訓(xùn)練效率,簡單直接的結(jié)構(gòu)和極快的推理速度。
與RepVGG的不同:
1)RepVGG在訓(xùn)練過程中加入了結(jié)構(gòu)先驗(yàn)(shortcut,1x1 branch),在推理時,將多支路融合成單路3x3卷積。而RepOptVGG將結(jié)構(gòu)先驗(yàn)轉(zhuǎn)移至梯度中,通過設(shè)計(jì)的RepOpt優(yōu)化器實(shí)現(xiàn)。
2)在結(jié)構(gòu)上,RepOptVGG是真-直筒結(jié)構(gòu),模型在訓(xùn)練與測試時保持一致。RepVGG訓(xùn)練時存在多支路需要更多的顯存與訓(xùn)練時間。
3)RepOptVGG通過定制優(yōu)化器,實(shí)現(xiàn)了結(jié)構(gòu)重參與梯度重參的等效變化。
Idea
Repopt發(fā)現(xiàn)結(jié)構(gòu)先驗(yàn)的一個有趣現(xiàn)象:當(dāng)每個分支只包含一個線性可訓(xùn)練算子,如果正確設(shè)置常尺度值,模型的性能會提高。我們將這種線性塊稱為Constant Scale Linear Addition(CSLA)。我們可以用單個算子替換一個CSLA塊,并通過設(shè)計(jì)優(yōu)化器改變梯度實(shí)現(xiàn)等價的訓(xùn)練動態(tài)。Repopt將這種乘數(shù)稱為Grad Mult,如上圖所示。
證明:用常規(guī)的SGD訓(xùn)練一個CSLA塊相當(dāng)于用修改的梯度訓(xùn)練一個簡單的卷積
CSLA塊中每個分支只包含一個可訓(xùn)練線性算子,并且結(jié)構(gòu)中不存在BN或者dropout等非線性操作。Repopt發(fā)現(xiàn)用常規(guī)的SGD訓(xùn)練一個CSLA塊相當(dāng)于用修改的梯度訓(xùn)練一個簡單的卷積。下面用一個簡單的例子證明這個結(jié)論。
假設(shè)CSLA由兩個相同形狀的卷積組成,其中每個核包含一個可訓(xùn)練線性算子。如下面公式所示,其中αA,αB\alpha_A,\alpha_BαA?,αB?為可訓(xùn)練線性算子,W為卷積的參數(shù),X是輸入,Y為CSLA的輸出,*表示卷積操作。
對應(yīng)的梯度重參公式YGR=X?W′Y_{GR}=X*W^{\prime}YGR?=X?W′,其中W′W^{\prime}W′表示梯度重參后的卷積,假設(shè)損失函數(shù)為L,訓(xùn)練迭代數(shù)為i,卷積參數(shù)W的梯度表示為?L?W\frac{\partial L}{\partial W}?W?L?,F(?L?W′)F(\frac{\partial L}{\partial W^{\prime}})F(?W′?L?)表示對應(yīng)梯度重參上的任意變化,我們希望通過數(shù)次訓(xùn)練后CSLA的輸出與梯度重參后的輸出一致,即
通過卷積的線性可加性,我們需要保證公式6
在i=0迭代開始前,正確的初始化確保了公式6的等價性,初始化如公式7所示
下面,我們用數(shù)學(xué)歸納法證明在W′W^{\prime}W′的梯度上進(jìn)行適當(dāng)?shù)淖儞Q后,公式6的等價性始終成立,W梯度更新的公式如下
更新相應(yīng)的CSLA塊,我們獲得公式10
我們使用F(?L?W′)F(\frac{\partial L}{\partial W^{\prime}})F(?W′?L?)來更新W′W^{\prime}W′,這就意味著
假設(shè)在迭代第i次時,公式6,10,11成立,那么可以獲得公式12
對公式6取偏導(dǎo)數(shù),有公式13
我們獲得等式14,即F(?L?W′)F(\frac{\partial L}{\partial W^{\prime}})F(?W′?L?)的準(zhǔn)確形式
由公式11,14,我們可以推到出,當(dāng)?shù)絠+1次時,下面等式成立
由于假設(shè)公式6成立
通過初始條件公式7,8,以及數(shù)學(xué)歸納法我們可以證明當(dāng)i>=0時,公式6成立。同時,我們知道F(?L?W′)F(\frac{\partial L}{\partial W^{\prime}})F(?W′?L?)的準(zhǔn)確形式,如公式14所示。
Method
上文,已經(jīng)介紹了Repopt找到一個合適的結(jié)構(gòu)先驗(yàn)CSLA塊,并通過數(shù)學(xué)歸納證明可以通過梯度重參將CSLA等效為簡單的卷積操作,下面,我們使用RepOpt-VGG作為展示例,具體介紹Repopt如何設(shè)計(jì)和描述梯度重參的行為。
在RepOptVGG中,對應(yīng)的CSLA塊則是將RepVGG塊中的3x3卷積,1x1卷積,bn層替換為帶可學(xué)習(xí)縮放參數(shù)的3x3卷積,1x1卷積。進(jìn)一步拓展到多分支中,假設(shè)s,t分別是3x3卷積,1x1卷積的縮放系數(shù),那么對應(yīng)的更新規(guī)則為:
對公式3的理解需要結(jié)合RepVGG,當(dāng)輸入與輸出通道不等時,只存在conv3x3, conv1x1兩個分支,其中conv1x1可以等效為特殊的conv3x3,因此梯度可以重參為sc2+tc2s_c^2+t_c^2sc2?+tc2?,如上文所證明一樣。而當(dāng)輸入與輸出通道相等時,此時一共有3個分支,分別是identity,conv3x3, conv1x1,Identity也可以等效為特殊的conv3x3,其卷積核由0,1組成,所以梯度重參為1+sc2+tc21+s_c^2+t_c^21+sc2?+tc2?。
需要注意的是CSLA沒有BN這種訓(xùn)練期間非線性算子(training-time nonlinearity),也沒有非順序性(non sequential)可訓(xùn)練參數(shù),CSLA在這里只是一個描述RepOptimizer的間接工具。
那么剩下一個問題,即如何確定這個縮放系數(shù)
HyperSearch
受DARTS啟發(fā),我們將CSLA中的常數(shù)縮放系數(shù),替換成可訓(xùn)練參數(shù)。在一個小數(shù)據(jù)集(如CIFAR100)上進(jìn)行訓(xùn)練,在小數(shù)據(jù)上訓(xùn)練完畢后,我們將這些可訓(xùn)練參數(shù)固定為常數(shù)。
Code
LinearAddBlock定義的是CSLA塊,該模塊只在確定HyperSearch的時候被訓(xùn)練。
class LinearAddBlock(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,dilation=1, groups=1, padding_mode='zeros', use_se=False, is_csla=False, conv_scale_init=1.0):super(LinearAddBlock, self).__init__()self.in_channels = in_channelsself.relu = nn.ReLU()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)self.scale_conv = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=conv_scale_init)self.conv_1x1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)self.scale_1x1 = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=conv_scale_init)if in_channels == out_channels and stride == 1:self.scale_identity = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=1.0)self.bn = nn.BatchNorm2d(out_channels)if is_csla: # Make them constantself.scale_1x1.requires_grad_(False)self.scale_conv.requires_grad_(False)if use_se:raise NotImplementedError("se block not supported yet")else:self.se = nn.Identity()def forward(self, inputs):out = self.scale_conv(self.conv(inputs)) + self.scale_1x1(self.conv_1x1(inputs))if hasattr(self, 'scale_identity'):out += self.scale_identity(inputs)out = self.relu(self.se(self.bn(out)))return outclass ScaleLayer(torch.nn.Module):def __init__(self, num_features, use_bias=True, scale_init=1.0):super(ScaleLayer, self).__init__()self.weight = Parameter(torch.Tensor(num_features))init.constant_(self.weight, scale_init)self.num_features = num_featuresif use_bias:self.bias = Parameter(torch.Tensor(num_features))init.zeros_(self.bias)else:self.bias = Nonedef forward(self, inputs):if self.bias is None:return inputs * self.weight.view(1, self.num_features, 1, 1)else:return inputs * self.weight.view(1, self.num_features, 1, 1) + self.bias.view(1, self.num_features, 1, 1)RealVGGBlock是RepOptVGG的真實(shí)模塊,結(jié)構(gòu)簡單如下所示。
class RealVGGBlock(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,dilation=1, groups=1, padding_mode='zeros', use_se=False,):super(RealVGGBlock, self).__init__()self.relu = nn.ReLU()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)self.bn = nn.BatchNorm2d(out_channels)if use_se:raise NotImplementedError("se block not supported yet")else:self.se = nn.Identity()def forward(self, inputs):out = self.relu(self.se(self.bn(self.conv(inputs))))return out假設(shè)我們已經(jīng)通過小數(shù)據(jù)訓(xùn)練獲得了HyperSearch需要的scales,那么在訓(xùn)練RepOptVGG時,RepVGGOptimizer需要在初始化時候?qū)SLA塊的scales賦值給RealVGGBlock,賦值的過程如reinitialize所示,對應(yīng)了Method中的公式3。
def reinitialize(self, scales_by_idx, conv3x3_by_idx, use_identity_scales):for scales, conv3x3 in zip(scales_by_idx, conv3x3_by_idx):in_channels = conv3x3.in_channelsout_channels = conv3x3.out_channelskernel_1x1 = nn.Conv2d(in_channels, out_channels, 1, device=conv3x3.weight.device)if len(scales) == 2:conv3x3.weight.data = conv3x3.weight * scales[1].view(-1, 1, 1, 1) \+ F.pad(kernel_1x1.weight, [1, 1, 1, 1]) * scales[0].view(-1, 1, 1, 1)else:assert len(scales) == 3assert in_channels == out_channelsidentity = torch.from_numpy(np.eye(out_channels, dtype=np.float32).reshape(out_channels, out_channels, 1, 1)).to(conv3x3.weight.device)conv3x3.weight.data = conv3x3.weight * scales[2].view(-1, 1, 1, 1) + F.pad(kernel_1x1.weight, [1, 1, 1, 1]) * scales[1].view(-1, 1, 1, 1)if use_identity_scales: # You may initialize the imaginary CSLA block with the trained identity_scale values. Makes almost no difference.identity_scale_weight = scales[0]conv3x3.weight.data += F.pad(identity * identity_scale_weight.view(-1, 1, 1, 1), [1, 1, 1, 1])else:conv3x3.weight.data += F.pad(identity, [1, 1, 1, 1])我們在梯度重參過程中需要獲取梯度Mask,與reinitialize過程相似分為3種情況,具體實(shí)現(xiàn)如下所示。
def generate_gradient_masks(self, scales_by_idx, conv3x3_by_idx, cpu_mode=False):self.grad_mask_map = {}for scales, conv3x3 in zip(scales_by_idx, conv3x3_by_idx):para = conv3x3.weightif len(scales) == 2:mask = torch.ones_like(para, device=scales[0].device) * (scales[1] ** 2).view(-1, 1, 1, 1)mask[:, :, 1:2, 1:2] += torch.ones(para.shape[0], para.shape[1], 1, 1, device=scales[0].device) * (scales[0] ** 2).view(-1, 1, 1, 1)else:mask = torch.ones_like(para, device=scales[0].device) * (scales[2] ** 2).view(-1, 1, 1, 1)mask[:, :, 1:2, 1:2] += torch.ones(para.shape[0], para.shape[1], 1, 1, device=scales[0].device) * (scales[1] ** 2).view(-1, 1, 1, 1)ids = np.arange(para.shape[1])assert para.shape[1] == para.shape[0]mask[ids, ids, 1:2, 1:2] += 1.0if cpu_mode:self.grad_mask_map[para] = maskelse:self.grad_mask_map[para] = mask.cuda()通過Repopt梯度重參的方式將結(jié)構(gòu)先驗(yàn)轉(zhuǎn)化為梯度先驗(yàn),可以統(tǒng)一訓(xùn)練與測試模型結(jié)構(gòu),有效解決RepVGG量化不友好問題,其結(jié)構(gòu)在YOLOV6中被使用,并表現(xiàn)出極佳的性能。
總結(jié)
以上是生活随笔為你收集整理的Repoptimizer论文理解与代码分析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 优秀的linux学习网站
- 下一篇: 几个冷门linux与BSD发行版中文学习