[Deep Learning] Preprint | MAE, the new work from Kaiming He | CVPR 2022 Best Paper Candidate
Reposted from the WeChat public account 機器學習煉丹術
Notes by 煉丹兄 (reposted with permission)
Contact: WeChat cyx645016617
Paper: "Masked Autoencoders Are Scalable Vision Learners"
0 Abstract
This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. The MAE approach is simple: random patches of the input image are masked out and the missing pixels are reconstructed.
The design rests on two core ideas:
First, an asymmetric encoder-decoder architecture: the encoder operates only on the visible subset of patches (without mask tokens), while a lightweight decoder reconstructs the original image from the latent representation and mask tokens.
Second, masking a high proportion of the input image (e.g., 75%) turns out to be a nontrivial and meaningful self-supervisory task. Coupling these two designs makes it possible to train large models efficiently: training is accelerated (by 3x or more) and accuracy improves.
1 Method
As the overview figure in the paper shows, the model is actually very simple:
It is a ViT-style transformer structure: the image is split into patches, and the model can only see a small fraction (25%) of them, while the remaining 75% of the patches are hidden;
the encoder's input is the visible 25% of the patches together with their positional embeddings;
the decoder then takes the encoded information of those 25% of the patches (plus mask tokens) and restores the whole image, i.e., performs the reconstruction.
After pre-training, the decoder is discarded and the encoder is applied to uncorrupted images to produce representations for recognition tasks.
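To make the masking design concrete, here is a minimal sketch of how a random 75% patch mask per image can be generated. The helper name random_patch_mask and the convention that True marks a masked patch are my own assumptions, chosen to match the reimplementation discussed below.

import torch

def random_patch_mask(batch_size, num_patches, mask_ratio=0.75):
    # boolean mask of shape [B, N]; True = masked patch (assumed convention)
    num_mask = int(num_patches * mask_ratio)
    mask = torch.zeros(batch_size, num_patches, dtype=torch.bool)
    for b in range(batch_size):
        idx = torch.randperm(num_patches)[:num_mask]   # pick 75% of the patches to hide
        mask[b, idx] = True
    return mask

mask = random_patch_mask(batch_size=2, num_patches=196)   # a 224x224 image with 16x16 patches has 196 patches
print(mask.shape, mask.sum(dim=1))                         # torch.Size([2, 196]) tensor([147, 147])

Because every image keeps the same number of visible patches (49 here), the visible tokens of a whole batch can later be packed into one dense tensor.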
2 Code - Step 1
Since the model is simple, let's go straight to the code. Note that the code below is a community reimplementation, not the official release!
# Excerpt from the (unofficial) reimplementation. It assumes the usual imports at the top of the
# module, e.g. import torch, import torch.nn as nn, from functools import partial, plus helpers
# such as _cfg and PretrainVisionTransformer defined in the same file.
def pretrain_mae_small_patch16_224(pretrained=False, **kwargs):
    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=384,
        encoder_depth=12,
        encoder_num_heads=6,
        encoder_num_classes=0,
        decoder_num_classes=768,
        decoder_embed_dim=192,
        decoder_depth=4,
        decoder_num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

From parameters such as patch_size and encoder_embed_dim, it is not hard to see that PretrainVisionTransformer is a classic ViT-style transformer structure (a guess for now, verified below).
3 Code - Step 2
class PretrainVisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage """
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 encoder_in_chans=3,
                 encoder_num_classes=0,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 decoder_num_classes=768,
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=8,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 num_classes=0,  # avoid the error from create_fn in timm
                 in_chans=0,     # avoid the error from create_fn in timm
                 ):
        super().__init__()
        self.encoder = PretrainVisionTransformerEncoder(
            img_size=img_size, patch_size=patch_size, in_chans=encoder_in_chans,
            num_classes=encoder_num_classes, embed_dim=encoder_embed_dim, depth=encoder_depth,
            num_heads=encoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_layer=norm_layer, init_values=init_values,
            use_learnable_pos_emb=use_learnable_pos_emb)

        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size, num_patches=self.encoder.patch_embed.num_patches,
            num_classes=decoder_num_classes, embed_dim=decoder_embed_dim, depth=decoder_depth,
            num_heads=decoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_layer=norm_layer, init_values=init_values)

        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)

        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'mask_token'}

    def forward(self, x, mask):
        x_vis = self.encoder(x, mask)  # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, N, C = x_vis.shape
        # we don't unshuffle the correct visible token order,
        # but shuffle the pos embedding accordingly.
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        x = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
        return x

Overall, the model is composed of an Encoder and a Decoder. Let's list the constructor parameters (a minimal usage sketch follows the list):
img_size=224
patch_size=16
encoder_in_chans=3
encoder_num_classes=0
encoder_embed_dim=768
encoder_depth=12
encoder_num_heads=12
decoder_num_classes=768
decoder_embed_dim=512
decoder_depth=8
decoder_num_heads=8
mlp_ratio=4.
qkv_bias=False
qk_scale=None
drop_rate=0.
attn_drop_rate=0.
drop_path_rate=0.
norm_layer=nn.LayerNorm
init_values=0.
use_learnable_pos_emb=False
num_classes=0 # avoid the error from create_fn in timm
in_chans=0, # avoid the error from create_fn in timm
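As mentioned above, here is a minimal usage sketch of the pre-training forward pass, reusing the random_patch_mask helper sketched in the Method section. It assumes the classes from this walkthrough are importable; the shapes in the comments follow from the defaults listed above.

import torch

model = PretrainVisionTransformer()                        # the defaults listed above
imgs = torch.randn(2, 3, 224, 224)
mask = random_patch_mask(batch_size=2, num_patches=196)    # True = masked, 147 per image

pred = model(imgs, mask)
print(pred.shape)                                          # expected: torch.Size([2, 147, 768]) = [B, num_mask, 3*16*16]

Only the masked patches are predicted, so the output carries one 768-dimensional pixel vector (3 channels x 16 x 16) per masked patch.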
4 Code - Encoder
class PretrainVisionTransformerEncoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_learnable_pos_emb=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # TODO: Add the cls token
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            # sine-cosine positional embeddings
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

        # trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, mask):
        x = self.patch_embed(x)

        # cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()

        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible

        for blk in self.blocks:
            x_vis = blk(x_vis)

        x_vis = self.norm(x_vis)
        return x_vis

    def forward(self, x, mask):
        x = self.forward_features(x, mask)
        x = self.head(x)
        return x

The Encoder is built from the following modules (a short note on the mask indexing follows this list):
self.patch_embed: turns the image into patch embeddings;
depth stacked Blocks: the transformer's feature-extraction part;
self.head: an Identity layer here, so it does nothing.
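The key line in forward_features is x_vis = x[~mask].reshape(B, -1, C): the masked patches are dropped before any transformer block runs, so the encoder only ever attends over the visible 25%. A small sketch of that indexing with illustrative shapes (my own numbers, not from the repository):

import torch

B, N, C = 2, 196, 768
x = torch.randn(B, N, C)                     # all patch embeddings, positions already added
mask = torch.zeros(B, N, dtype=torch.bool)
mask[:, 49:] = True                          # pretend the last 147 patches (75%) are masked

x_vis = x[~mask].reshape(B, -1, C)           # keep only the visible tokens
print(x_vis.shape)                           # torch.Size([2, 49, 768])

The reshape only works because every image in the batch keeps the same number of visible patches; shrinking the sequence from 196 to 49 tokens is where most of the claimed pre-training speed-up comes from.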
5 Code - PatchEmbed
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

As the code shows, the module really contains just the single convolution used in self.proj(x). I wrote a simple demo to study how PatchEmbed changes the shape of an image:
The input is a 1x3x224x224 tensor, and the output y has shape [1, 196, 768].
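A sketch of that demo (the variable names are mine; it assumes the PatchEmbed class above is available):

import torch

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(1, 3, 224, 224)
y = patch_embed(x)
print(y.shape)   # torch.Size([1, 196, 768])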
From this, the process and the meaning of the two numbers become clear:
196 is the number of patches in one image: with a 224 input and a patch size of 16, an image has (224/16)^2 = 196 patches;
each patch is encoded by the convolution into a 768-dimensional vector; 768 corresponds to the hyperparameter embed_dim.
Setting kernel_size and stride both equal to the patch size makes this convolution mathematically equivalent to a single fully-connected layer applied to all the values of one patch. A patch covers 16x16 pixels over 3 channels, i.e., 16*16*3 = 768 input values, so the convolution is equivalent to a 768-to-768 fully-connected layer (see the check below).
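A quick numerical check of that equivalence (my own sketch, not from the original post):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)

# build the equivalent fully-connected layer by flattening the conv kernel
fc = nn.Linear(3 * 16 * 16, 768)
fc.weight.data = conv.weight.data.reshape(768, -1)
fc.bias.data = conv.bias.data

patch = torch.randn(1, 3, 16, 16)                    # one 16x16 patch
out_conv = conv(patch).flatten(1)                    # [1, 768]
out_fc = fc(patch.flatten(1))                        # [1, 768]
print(torch.allclose(out_conv, out_fc, atol=1e-5))   # True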
6 Code - Block
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

This Block contains three modules: Attention, Mlp, and DropPath.
The input x first goes through LayerNorm, then Attention, then DropPath, and the result is added back to x as a residual; after that comes a second LayerNorm, the Mlp, and another DropPath, again with a residual connection (the standard pre-norm transformer block).
7 Code - Attention
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

The self.qkv fully-connected layer expands the 768 input features to 2304 dimensions, which are then split into the three variables q, k, and v.
The reshape turns [batch, 196, 2304] into [batch, 196, 3, 8, 96], and the permute gives [3, batch, 8, 196, 96]; the leading 3 is split into q, k, and v. After the two matrix multiplications (attention weights times v, followed by the output projection), the output is again [batch, 196, 768].
[Summary]: Attention is simply a feature-extraction module: the input is [batch, 196, 768] and the output is also [batch, 196, 768].
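A small standalone sketch of those shape manipulations (illustrative numbers, using the author's example of 8 heads rather than the encoder default of 12):

import torch

B, N, C, heads = 1, 196, 768, 8
head_dim = C // heads                                  # 96

qkv = torch.randn(B, N, 3 * C)                         # output of the qkv linear layer: [1, 196, 2304]
qkv = qkv.reshape(B, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)   # [3, 1, 8, 196, 96]
q, k, v = qkv[0], qkv[1], qkv[2]                       # each [1, 8, 196, 96]

attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)    # [1, 8, 196, 196]
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, -1)
print(out.shape)                                       # torch.Size([1, 196, 768])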
8 Code - Mlp
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the original BERT implement
        x = self.fc2(x)
        x = self.drop(x)
        return x

This MLP is just two fully-connected layers: 768 is expanded to 768x4 = 3072 and then projected back to 768.
9 Code - Decoder
class PretrainVisionTransformerDecoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage """
    def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_patches=196,
                 ):
        super().__init__()
        self.num_classes = num_classes
        assert num_classes == 3 * patch_size ** 2
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_size = patch_size

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, return_token_num):
        for blk in self.blocks:
            x = blk(x)

        if return_token_num > 0:
            x = self.head(self.norm(x[:, -return_token_num:]))  # only return the mask tokens' predicted pixels
        else:
            x = self.head(self.norm(x))  # [B, N, 3*16^2]
        return x

Overall, though, this reimplementation still differs from the MAE described in the paper in a few places; the decoder part has issues, which I will fix myself later.
In my view, the main gap is that, after the encoder and before the decoder, this code never restores the patches to their original positions in the image, i.e., the "unshuffle" step highlighted in the paper's pipeline figure.
Whether or not this step is present does not affect training; it only matters when you want to assemble a complete reconstructed image. A sketch of how that reassembly could look is given below.
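As a rough illustration of that missing step, here is a sketch of how the decoder's per-patch pixel predictions could be scattered back into a full image for visualization. It is my own code, assumes the boolean-mask convention used above (True = masked, same number of masked patches per image), and ignores the per-patch normalization that the real MAE applies to its targets:

import torch

def reassemble_image(imgs, pred, mask, patch_size=16):
    # imgs: [B, 3, H, W]; pred: [B, num_masked, 3*p*p]; mask: [B, N] bool, True = masked
    B, C, H, W = imgs.shape
    p = patch_size

    # patchify the original image into [B, N, 3*p*p]
    patches = imgs.reshape(B, C, H // p, p, W // p, p)
    patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)

    # copy visible patches from the input, write predictions into the masked slots
    full = patches.clone()
    full[mask] = pred.reshape(-1, C * p * p).to(patches.dtype)

    # unpatchify back to an image: [B, 3, H, W]
    full = full.reshape(B, H // p, W // p, C, p, p).permute(0, 3, 1, 4, 2, 5)
    return full.reshape(B, C, H, W)

Again, this step is only needed for visualization: during training the loss is computed directly between pred and the pixels of the masked patches, so the model trains fine without it.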