Manifold Mixup and PatchUp are two improvements on the mixup data-augmentation algorithm, both from Yoshua Bengio's group. Both generalize mixup from the raw input to intermediate hidden layers, so the official open-source code has to modify the internals of the network layer by layer, which is inconvenient and not plug-and-play. I re-implemented both methods with PyTorch forward hooks, so they can be dropped into arbitrary network architectures without touching their code, and the hook-based implementation also runs roughly 60% faster than the official code.
Manifold Mixup paper: https://arxiv.org/abs/1806.05236
Manifold Mixup official code: https://github.com/vikasverma1077/manifold_mixup
PatchUp paper: https://arxiv.org/abs/2006.07794
PatchUp official code: https://github.com/chandar-lab/PatchUp
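The whole plug-and-play trick rests on one PyTorch mechanism: a forward hook registered with register_forward_hook may return a tensor, and that tensor then replaces the module's output seen by the rest of the network. A minimal, self-contained illustration of this mechanism (the toy network and hook names are mine, not part of either method):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def scale_output(module, inputs, output):
    return output * 0.5          # the returned tensor replaces the module's output

handle = net[0].register_forward_hook(scale_output)  # attach to an inner layer
y = net(torch.randn(3, 4))                           # forward pass now uses the modified output
handle.remove()                                       # detach when done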
1. Manifold Mixup: Overview and Code
Manifold Mixup extends mixup from mixing raw input data to mixing the outputs of intermediate hidden layers. The authors' explanation of why mixing hidden representations is more effective is fairly involved. They first give a phenomenological account: this kind of mixing brings three benefits, namely smoother decision boundaries, larger low-confidence regions (i.e., wider gaps between the high-confidence regions of different classes), and flattened hidden representations. Why these three properties help is treated as more or less accepted wisdom in the field; the authors then analyze the third point mathematically, i.e., why Manifold Mixup flattens the hidden-layer representations.
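Whichever layer is chosen (including the input itself, which recovers plain mixup), the mixing operation is the same interpolation applied to that layer's activations and mirrored on the labels. A minimal sketch of the operation (function and variable names are mine):

import torch

def mixup_hidden(h, y_onehot, lam):
    # h: batch of raw inputs or hidden activations; y_onehot: one-hot labels
    perm = torch.randperm(h.size(0), device=h.device)
    h_mix = lam * h + (1 - lam) * h[perm]                # interpolate representations
    y_mix = lam * y_onehot + (1 - lam) * y_onehot[perm]  # interpolate labels with the same permutation
    return h_mix, y_mix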
Because the method needs to modify the output tensor of an intermediate layer, it can be done externally with a forward hook instead of editing the network internals. The core code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def to_one_hot(inp, num_classes):
    y_onehot = torch.FloatTensor(inp.size(0), num_classes).to(inp.device)
    y_onehot.zero_()
    y_onehot.scatter_(1, inp.unsqueeze(1).data, 1)
    return y_onehot

bce_loss = nn.BCELoss()
softmax = nn.Softmax(dim=1)

class ManifoldMixupModel(nn.Module):
    def __init__(self, model, num_classes=10, alpha=1):
        super().__init__()
        self.model = model
        self.alpha = alpha
        self.lam = None
        self.num_classes = num_classes
        # Select the layers to hook. In ResNet the residual stages are named
        # layer1, layer2, ..., so the filter below works; adapt it for other networks.
        self.module_list = []
        for n, m in self.model.named_modules():
            # if 'conv' in n:
            if n[:-1] == 'layer':
                self.module_list.append(m)

    def forward(self, x, target=None):
        if target is None:
            out = self.model(x)
            return out
        else:
            if self.alpha <= 0:
                self.lam = 1
            else:
                self.lam = np.random.beta(self.alpha, self.alpha)
            # k == -1 mixes the raw input (plain mixup); otherwise mix the output of layer k.
            k = np.random.randint(-1, len(self.module_list))
            self.indices = torch.randperm(target.size(0)).cuda()
            target_onehot = to_one_hot(target, self.num_classes)
            target_shuffled_onehot = target_onehot[self.indices]
            if k == -1:
                x = x * self.lam + x[self.indices] * (1 - self.lam)
                out = self.model(x)
            else:
                modifier_hook = self.module_list[k].register_forward_hook(self.hook_modify)
                out = self.model(x)
                modifier_hook.remove()
            target_reweighted = target_onehot * self.lam + target_shuffled_onehot * (1 - self.lam)
            loss = bce_loss(softmax(out), target_reweighted)
            return out, loss

    def hook_modify(self, module, input, output):
        output = self.lam * output + (1 - self.lam) * output[self.indices]
        return output
The calling code is as follows:
net = ResNet18()
net = ManifoldMixupModel(net, num_classes=10, alpha=args.alpha)

def train(epoch):
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs, loss = net(inputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(epoch):
    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
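A note on the layer filter: `n[:-1] == 'layer'` matches exactly the names layer1 ... layer4 of a ResNet, so the hook is placed after each residual stage. For a different backbone you need to pick your own hook points; a quick way to see what names are available is to list the named modules. A sketch for a hypothetical VGG backbone (assuming torchvision's vgg16; adapt the filter to whatever modules you want to hook):

import torch
import torchvision

backbone = torchvision.models.vgg16()
for name, module in backbone.named_modules():
    if isinstance(module, torch.nn.MaxPool2d):   # e.g. hook after each pooling stage
        print(name, module)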
2. PatchUp: Overview and Code
PatchUp builds on Manifold Mixup and additionally borrows CutMix's idea of spatial cropping: it crops the intermediate hidden-layer outputs as well, and either swaps or interpolates the cropped blocks (patches) between two different samples. The paper calls the swapping variant Hard PatchUp and the interpolation variant Soft PatchUp. Experiments show that swapping gives better recognition accuracy, while interpolation is more robust against adversarial attacks. The paper offers no deep theoretical explanation of the method, only a phenomenological comparison, namely that the hidden-layer activations under PatchUp are higher.
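Stripped of the mask-sampling details, the core operation is: given a binary mask over feature-map positions, keep the unmasked part and either swap (hard) or interpolate (soft) the masked part with another sample. A rough sketch under these assumptions (variable and function names are mine, not the paper's):

import torch

def patch_mix(h, mask, perm, lam, hard=True):
    # h: (B, C, H, W) hidden activations; mask: 1 where patches are replaced (broadcastable to h)
    unchanged = (1 - mask) * h
    if hard:
        patches = mask * h[perm]                              # swap patches from the shuffled batch
    else:
        patches = mask * (lam * h + (1 - lam) * h[perm])      # interpolate inside the patches
    return unchanged + patches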
The core PatchUpModel class implemented with a hook is given below; note that forcing k = -1 in this code turns it into CutMix:
class PatchUpModel(nn.Module):
    def __init__(self, model, num_classes=10, block_size=7, gamma=.9, patchup_type='hard', keep_prob=.9):
        super().__init__()
        self.patchup_type = patchup_type
        self.block_size = block_size
        self.gamma = gamma
        self.gamma_adj = None
        self.kernel_size = (block_size, block_size)
        self.stride = (1, 1)
        self.padding = (block_size // 2, block_size // 2)
        self.computed_lam = None
        self.model = model
        self.num_classes = num_classes
        # Same layer selection as ManifoldMixupModel: hook the ResNet stages layer1 ... layer4.
        self.module_list = []
        for n, m in self.model.named_modules():
            # if 'conv' in n:
            if n[:-1] == 'layer':
                self.module_list.append(m)

    def adjust_gamma(self, x):
        # Rescale gamma so the expected dropped area stays at gamma after block expansion.
        return self.gamma * x.shape[-1] ** 2 / \
            (self.block_size ** 2 * (x.shape[-1] - self.block_size + 1) ** 2)

    def forward(self, x, target=None):
        if target is None:
            out = self.model(x)
            return out
        else:
            self.lam = np.random.beta(2.0, 2.0)
            # k == -1 applies CutMix on the raw input; otherwise PatchUp on the output of layer k.
            k = np.random.randint(-1, len(self.module_list))
            self.indices = torch.randperm(target.size(0)).cuda()
            self.target_onehot = to_one_hot(target, self.num_classes)
            self.target_shuffled_onehot = self.target_onehot[self.indices]
            if k == -1:  # CutMix
                W, H = x.size(2), x.size(3)
                cut_rat = np.sqrt(1. - self.lam)
                cut_w = int(W * cut_rat)
                cut_h = int(H * cut_rat)
                cx = np.random.randint(W)
                cy = np.random.randint(H)
                bbx1 = np.clip(cx - cut_w // 2, 0, W)
                bby1 = np.clip(cy - cut_h // 2, 0, H)
                bbx2 = np.clip(cx + cut_w // 2, 0, W)
                bby2 = np.clip(cy + cut_h // 2, 0, H)
                x[:, :, bbx1:bbx2, bby1:bby2] = x[self.indices, :, bbx1:bbx2, bby1:bby2]
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
                out = self.model(x)
                loss = bce_loss(softmax(out), self.target_onehot) * lam + \
                       bce_loss(softmax(out), self.target_shuffled_onehot) * (1. - lam)
            else:
                modifier_hook = self.module_list[k].register_forward_hook(self.hook_modify)
                out = self.model(x)
                modifier_hook.remove()
                loss = 1.0 * bce_loss(softmax(out), self.target_a) * self.total_unchanged_portion + \
                       bce_loss(softmax(out), self.target_b) * (1. - self.total_unchanged_portion) + \
                       1.0 * bce_loss(softmax(out), self.target_reweighted)
            return out, loss

    def hook_modify(self, module, input, output):
        self.gamma_adj = self.adjust_gamma(output)
        p = torch.ones_like(output[0]) * self.gamma_adj
        m_i_j = torch.bernoulli(p)
        mask_shape = len(m_i_j.shape)
        m_i_j = m_i_j.expand(output.size(0), m_i_j.size(0), m_i_j.size(1), m_i_j.size(2))
        # Dilate each Bernoulli "seed" into a block_size x block_size hole, DropBlock-style.
        holes = F.max_pool2d(m_i_j, self.kernel_size, self.stride, self.padding)
        mask = 1 - holes
        unchanged = mask * output
        if mask_shape == 1:
            total_feats = output.size(1)
        else:
            total_feats = output.size(1) * (output.size(2) ** 2)
        total_changed_pixels = holes[0].sum()
        total_changed_portion = total_changed_pixels / total_feats
        self.total_unchanged_portion = (total_feats - total_changed_pixels) / total_feats
        if self.patchup_type == 'hard':
            self.target_reweighted = self.total_unchanged_portion * self.target_onehot + \
                                     total_changed_portion * self.target_shuffled_onehot
            patches = holes * output[self.indices]
            self.target_b = self.target_onehot[self.indices]
        elif self.patchup_type == 'soft':
            self.target_reweighted = self.total_unchanged_portion * self.target_onehot + \
                                     self.lam * total_changed_portion * self.target_onehot + \
                                     (1 - self.lam) * total_changed_portion * self.target_shuffled_onehot
            patches = holes * output
            patches = patches * self.lam + patches[self.indices] * (1 - self.lam)
            self.target_b = self.lam * self.target_onehot + (1 - self.lam) * self.target_shuffled_onehot
        else:
            raise ValueError("patchup_type must be 'hard' or 'soft'.")
        output = unchanged + patches
        self.target_a = self.target_onehot
        return output
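The mask construction inside hook_modify mirrors DropBlock: sample sparse Bernoulli "seed" positions with probability gamma_adj, then dilate each seed into a block_size x block_size square with max_pool2d. A tiny standalone demonstration of that trick (shapes and the seed probability are illustrative, not the values the class computes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
block_size = 3
feat = torch.zeros(1, 1, 8, 8)
seeds = torch.bernoulli(torch.ones_like(feat) * 0.05)        # sparse block centres
holes = F.max_pool2d(seeds, kernel_size=block_size, stride=1,
                     padding=block_size // 2)                 # dilate each centre into a 3x3 block
print(holes[0, 0])   # 1s mark the square patches that will be swapped / interpolated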
The calling procedure is the same as above; the model-wrapping statement is as follows:
net = ResNet18()
net = PatchUpModel(net, num_classes=10, block_size=7, gamma=.9, patchup_type='hard')
3. Experimental Results on CIFAR-10
The main purpose of the experiment was to verify that the code runs. A single run on one simple dataset is far from enough for a fair comparison, so the results should not be read as a performance comparison between the methods.