Application of the ADMM Algorithm to Neural Network Model Pruning
Table of Contents
- Preface
- 1. Alternating Direction Method of Multipliers
- 2. Formulation in the paper
- 3. Derivation of the formulas in the paper
- 4. Code workflow
- 5. Implementation of the main functions
- 6. dense vs. prune (finetune)
- Closing remarks
Preface
This post records my understanding of the ADMM algorithm described in the paper GRIM: A General, Real-Time Deep Learning Inference Framework for Mobile Devices based on Fine-Grained Structured Weight Sparsity, gives a derivation of the ADMM algorithm, and provides the implementation code at the end of the article.
1. Alternating Direction Method of Multipliers
The Alternating Direction Method of Multipliers (ADMM) is a computational framework for solving optimization problems and is suited to convex optimization. Its ideas can be traced back to the 1950s; it was proposed in the 1970s by R. Glowinski, D. Gabay and others as a simple and effective method for separable convex optimization, and in the 1980s and 1990s a large number of papers analyzed its properties, although at that time ADMM was mainly applied to partial differential equation problems. It has since been widely used in statistical machine learning, data mining, and computer vision. ADMM solves minimization problems over two blocks of variables subject to an equality constraint. It can be viewed as a development of the augmented Lagrangian method, combining the decomposability of dual ascent with the superior convergence of the method of multipliers. Compared with the method of multipliers, ADMM's greatest advantage is that it fully exploits the separability of the objective function and optimizes its variables alternately. For large-scale problems, ADMM can decompose the original objective into several tractable subproblems, solve the subproblems in parallel, and finally coordinate their solutions to obtain a global solution of the original problem.[1]
??優(yōu)化問(wèn)題
minimizef(x)+g(z)subjecttoAx+Bz=cminimize\ f(x)+g(z) \\ subject\ to\ Ax+Bz=cminimize?f(x)+g(z)subject?to?Ax+Bz=c??其中,x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rpx \in R^n,z \in R^m,A \in R^{p \times n},B \in R^{p \times m},c \in R^px∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rp,構(gòu)造拉格朗日函數(shù)為
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c)Lp?(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)??其增廣拉格朗日函數(shù)(augmented Lagrangian function)為
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)+ρ2∣∣Ax+Bz?c∣∣2L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c)+ \frac {\rho} {2}||Ax+Bz-c||^{2}Lp?(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)+2ρ?∣∣Ax+Bz?c∣∣2??對(duì)偶上升法迭代更新
(xk+1,zk+1)=argminx,zLp(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1?c)(x^{k+1},z^{k+1})=\underset {x,z} {argmin\ } L_p(x,z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c)(xk+1,zk+1)=x,zargmin??Lp?(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1?c)??交替方向乘子法則是在(x,z)(x,z)(x,z)一起迭代的基礎(chǔ)上將x,zx,zx,z分別固定單獨(dú)交替迭代,即
xk+1=argminxLp(x,zk,λk)zk+1=argminzLp(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1?c)x^{k+1}=\underset {x} {argmin\ }L_p(x,z^k,\lambda ^k) \\ z^{k+1}=\underset {z} {argmin\ }L_p(x^{k+1},z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c)xk+1=xargmin??Lp?(x,zk,λk)zk+1=zargmin??Lp?(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1?c)??交替方向乘子的另一種等價(jià)形式,將殘差定義為rk=Axk+Bzk?cr^k=Ax^k+Bz^k-crk=Axk+Bzk?c,同時(shí)定義uk=1ρλku^k=\frac {1} {\rho} \lambda ^kuk=ρ1?λk作為縮放的對(duì)偶變量(dual variable),有
(λk)Trk+ρ2∣∣rk∣∣2=ρ2∣∣rk+uk∣∣2?ρ2∣∣uk∣∣2(\lambda ^k)^Tr^k+\frac {\rho} {2} ||r^k||^2=\frac {\rho} {2}||r^k+u^k||^2-\frac {\rho} {2}||u^k||^2(λk)Trk+2ρ?∣∣rk∣∣2=2ρ?∣∣rk+uk∣∣2?2ρ?∣∣uk∣∣2??改寫 ADMM 的迭代過(guò)程
xk+1=argminx{f(x)+ρ2∣∣Ax+Bzk?c+uk∣∣2}zk+1=argminz{g(z)+ρ2∣∣Axk+1+Bz?c+uk∣∣2}uk+1=uk+Axk+1+Bzk+1?cx^{k+1} =\underset {x} {argmin\ }\bigg\{f(x)+\frac {\rho} {2}||Ax+Bz^k-c+u^k||^2\bigg\} \\[5pt] z^{k+1}=\underset {z} {argmin\ }\bigg\{g(z)+\frac {\rho} {2}||Ax^{k+1}+Bz-c+u^k||^2\bigg\} \\[5pt] u^{k+1}=u^k+Ax^{k+1}+Bz^{k+1}-c xk+1=xargmin??{f(x)+2ρ?∣∣Ax+Bzk?c+uk∣∣2}zk+1=zargmin??{g(z)+2ρ?∣∣Axk+1+Bz?c+uk∣∣2}uk+1=uk+Axk+1+Bzk+1?c
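To make the scaled-form iterations concrete, here is a minimal sketch that applies them to the lasso problem (least squares with an $\ell_1$ penalty), a toy example of my own and not the pruning formulation of the paper; the function and variable names are illustrative. The x-update is a linear solve and the z-update is element-wise soft-thresholding:

```python
import numpy as np

def lasso_admm(A, b, lam=0.1, rho=1.0, iters=200):
    """Scaled-form ADMM for: minimize 1/2*||Ax - b||^2 + lam*||z||_1  s.t.  x - z = 0."""
    n = A.shape[1]
    x, z, u = np.zeros(n), np.zeros(n), np.zeros(n)   # u is the scaled dual variable lambda/rho
    AtA, Atb = A.T @ A, A.T @ b
    M = AtA + rho * np.eye(n)                          # matrix of the x-update linear system
    for _ in range(iters):
        x = np.linalg.solve(M, Atb + rho * (z - u))    # x-update: (A^T A + rho I) x = A^T b + rho (z - u)
        v = x + u
        z = np.sign(v) * np.maximum(np.abs(v) - lam / rho, 0.0)   # z-update: soft-thresholding
        u = u + x - z                                  # dual update on the residual x - z
    return z

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 20))
    x_true = np.zeros(20)
    x_true[:3] = [1.0, -2.0, 0.5]
    b = A @ x_true + 0.01 * rng.standard_normal(50)
    print(np.round(lasso_admm(A, b), 2))               # recovers a sparse solution close to x_true
```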
2. Formulation in the paper
3. 對(duì)論文中的公式進(jìn)行推導(dǎo)
To simplify the derivation, the notation of the paper is condensed: the parameters W and b are written jointly as W. The optimization problem then becomes

$$
\begin{aligned}
&\text{minimize} && f(\{W_i\})+\sum_{i=1}^{N} g(Z_i) \\
&\text{subject to} && W_i=Z_i,\quad i=1,2,\dots,N
\end{aligned}
$$

The Lagrangian is

$$
L(W,Z,\lambda)=f(W)+\sum_{i} g(Z_i)+\sum_{i}\lambda_i^{T}(W_i-Z_i)
$$

and the augmented Lagrangian is

$$
L_{\rho}(W,Z,\lambda)=f(W)+\sum_{i} g(Z_i)+\sum_{i}\lambda_i^{T}(W_i-Z_i)+\sum_{i}\frac{\rho}{2}\|W_i-Z_i\|^{2}
$$

As before, ADMM fixes one block at a time and updates W and Z alternately:

$$
\begin{aligned}
W^{k+1}&=\underset{W}{\operatorname{argmin}}\ L_{\rho}(W,Z^{k},\lambda^{k}) \\
Z^{k+1}&=\underset{Z}{\operatorname{argmin}}\ L_{\rho}(W^{k+1},Z,\lambda^{k}) \\
\lambda_i^{k+1}&=\lambda_i^{k}+\rho(W_i^{k+1}-Z_i^{k+1})
\end{aligned}
$$

Define the scaled dual variables

$$
u_i^{k}=\frac{1}{\rho}\lambda_i^{k}
$$

and rewrite the ADMM iterations as

$$
\begin{aligned}
W^{k+1}&=\underset{W}{\operatorname{argmin}}\ \Big\{f(W)+\sum_{i}\frac{\rho}{2}\|W_i-Z_i^{k}+u_i^{k}\|^{2}\Big\} \\
Z^{k+1}&=\underset{Z}{\operatorname{argmin}}\ \Big\{\sum_{i} g(Z_i)+\sum_{i}\frac{\rho}{2}\|W_i^{k+1}-Z_i+u_i^{k}\|^{2}\Big\} \\
u_i^{k+1}&=u_i^{k}+W_i^{k+1}-Z_i^{k+1}
\end{aligned}
$$
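In the pruning setting, $g(Z_i)$ can be taken as the indicator function of the per-layer sparsity constraint set $S_i$ (at most a prescribed number of nonzero entries), as is common in ADMM-based weight pruning. Under that assumption the Z-update above has a closed form: it is the Euclidean projection of $W_i^{k+1}+u_i^{k}$ onto $S_i$, i.e. keep the largest-magnitude entries and zero out the rest, which is exactly what `update_Z` implements below with a percentile threshold ($\Pi_{S_i}$ is my notation, not the paper's):

$$
Z_i^{k+1}=\Pi_{S_i}\big(W_i^{k+1}+u_i^{k}\big)
$$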
4. Code workflow
The overall flow (ADMM training, pruning, then finetuning) is:

```python
# initialize the parameters Z and U
Z, U = initialize_Z_and_U(model)

# train the model and update W, Z, U; the loss function is the admm loss
for epoch in range(epochs):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = admm_loss(model, Z, U, output, target)
        loss.backward()
        optimizer.step()
    W = update_W(model)
    Z = update_Z(W, U, percent)
    U = update_U(U, W, Z)

# prune the weights and return the masks
mask = apply_prune(model, percent)

# finetune the pruned model
finetune(model, mask, train_loader, test_loader, optimizer)
```
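The flow above also calls `initialize_Z_and_U`, `apply_prune`, and `finetune`, which are not reproduced in the next section. Below is a minimal sketch of what they might look like, assuming that Z and U are tuples ordered like the weight tensors, that `percent` holds per-layer pruning ratios, and that pruning is done by zeroing masked weights in place; the signatures and details are my assumptions, not the original code.

```python
import torch
import torch.nn.functional as F


def initialize_Z_and_U(model):
    # Z starts as a copy of each weight tensor, U (scaled dual variable) starts at zero
    Z, U = (), ()
    for name, param in model.named_parameters():
        if name.split('.')[-1] == "weight":
            Z += (param.detach().cpu().clone(),)
            U += (torch.zeros_like(param).cpu(),)
    return Z, U


def apply_prune(model, percent):
    # threshold each weight tensor and zero out the pruned entries
    # (uses prune_weight from the next section); returns the binary masks
    device = next(model.parameters()).device
    masks = {}
    idx = 0
    for name, param in model.named_parameters():
        if name.split('.')[-1] == "weight":
            m = prune_weight(param, device, percent[idx])
            param.data.mul_(m)
            masks[name] = m
            idx += 1
    return masks


def finetune(model, mask, train_loader, test_loader, optimizer, epochs=3):
    # retrain the pruned model; re-apply the masks after every step so that
    # pruned weights stay at zero (evaluation on test_loader is omitted here)
    device = next(model.parameters()).device
    model.train()
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = F.nll_loss(model(data), target)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.mul_(mask[name])
```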
5. Implementation of the main functions

```python
import numpy as np
import torch
import torch.nn.functional as F


def admm_loss(args, device, model, Z, U, output, target):
    idx = 0
    loss = F.nll_loss(output, target)
    for name, param in model.named_parameters():
        if name.split('.')[-1] == "weight":
            u = U[idx].to(device)
            z = Z[idx].to(device)
            # ADMM penalty term from the derivation above
            # (strictly, the derivation uses the squared norm rho/2 * ||W - Z + U||^2)
            loss += args.rho / 2 * (param - z + u).norm()
            idx += 1
    return loss


def update_W(model):
    W = ()
    for name, param in model.named_parameters():
        if name.split('.')[-1] == "weight":
            W += (param.detach().cpu().clone(),)
    return W


def update_Z(W, U, args):
    new_Z = ()
    idx = 0
    for w, u in zip(W, U):
        z = w + u
        pcen = np.percentile(abs(z), 100 * args.percent[idx])
        under_threshold = abs(z) < pcen
        # percent is the pruning ratio: entries below the percentile are set to 0
        z.data[under_threshold] = 0
        new_Z += (z,)
        idx += 1
    return new_Z


def update_U(U, W, Z):
    new_U = ()
    for u, w, z in zip(U, W, Z):
        new_u = u + w - z
        new_U += (new_u,)
    return new_U


def prune_weight(weight, device, percent):
    # to work with admm, we calculate the percentile based on all elements instead of nonzero elements
    weight_numpy = weight.detach().cpu().numpy()
    pcen = np.percentile(abs(weight_numpy), 100 * percent)
    under_threshold = abs(weight_numpy) < pcen
    # unstructured pruning: zero out entries below the threshold
    weight_numpy[under_threshold] = 0
    mask = torch.Tensor(abs(weight_numpy) >= pcen).to(device)
    return mask
```

6. dense vs. prune (finetune)
結(jié)束語(yǔ)
??對(duì)論文中算法的推導(dǎo)僅限于自己的理解,可能還存在一些問(wèn)題,歡迎來(lái)評(píng)論區(qū)交流哦^_^
References
[1] 雷大江, 《分布式机器学习:交替方向乘子法在机器学习中的应用》 (Distributed Machine Learning: The Alternating Direction Method of Multipliers and Its Applications in Machine Learning).