ViG Core Code and Network Structure
This post walks through the core code of ViG (Vision GNN), using the pvig_ti_224_gelu configuration as the running example.
ti_vig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq
# Grapher, FFN, Downsample, Stem, act_layer and default_cfgs come from the
# ViG repository (gcn_lib and the vig_pytorch model file).

def pvig_ti_224_gelu(pretrained=False, **kwargs):
    class OptInit:
        def __init__(self, num_classes=1000, drop_path_rate=0.0, **kwargs):
            self.k = 9                          # number of neighbors, default 9
            self.conv = 'mr'                    # graph conv layer: max-relative ('mr')
            self.act = 'gelu'                   # activation layer
            self.norm = 'batch'                 # batch or instance normalization {batch, instance}
            self.bias = True                    # bias of conv layer, True or False
            self.dropout = 0.0                  # dropout rate
            self.use_dilation = True            # use dilated knn or not
            self.epsilon = 0.2                  # stochastic epsilon for gcn
            self.use_stochastic = False         # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 6, 2]          # number of basic blocks per backbone stage
            self.channels = [48, 96, 240, 384]  # number of channels of deep features
            self.n_classes = num_classes        # dimension of the output (number of classes)
            self.emb_dims = 1024                # dimension of embeddings

    opt = OptInit(**kwargs)
    model = DeepGCN(opt)
    model.default_cfg = default_cfgs['vig_224_gelu']
    return model

DeepGCN
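As a quick sanity check, here is a minimal usage sketch. It assumes the repo's modules (DeepGCN, Stem, Grapher, FFN, default_cfgs, ...) are in scope, e.g. imported from the vig_pytorch package of the Efficient-AI-Backbones repository; the shapes follow from the code above.

import torch

# Hypothetical usage sketch -- assumes pvig_ti_224_gelu and everything it
# depends on (DeepGCN, Stem, Grapher, FFN, default_cfgs, ...) is importable.
model = pvig_ti_224_gelu(num_classes=1000, drop_path_rate=0.0)
model.eval()

x = torch.randn(1, 3, 224, 224)   # a single 224x224 RGB image
with torch.no_grad():
    logits = model(x)             # shape: [1, 1000]
print(logits.shape)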
class DeepGCN(torch.nn.Module):
    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        print(opt)
        k = opt.k                        # k = 9
        act = opt.act                    # activation = gelu
        norm = opt.norm                  # norm = batch
        bias = opt.bias                  # bias = True
        epsilon = opt.epsilon            # epsilon = 0.2
        stochastic = opt.use_stochastic  # use_stochastic = False
        conv = opt.conv                  # conv = mr
        emb_dims = opt.emb_dims          # emb_dims = 1024
        drop_path = opt.drop_path        # drop_path = drop_path_rate = 0.0
        blocks = opt.blocks              # blocks = [2, 2, 6, 2]
        self.n_blocks = sum(blocks)      # n_blocks = 12
        channels = opt.channels          # channels = [48, 96, 240, 384] for pvig_ti
        reduce_ratios = [4, 2, 1, 1]
        # stochastic depth decay rule: dpr = [0.0] * 12 when drop_path = 0.0
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]
        # num_knn = [9] * 12
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]
        # maximum dilation: max_dilation = 49 // 9 = 5
        max_dilation = 49 // max(num_knn)

        # Stem(out_dim=48, act=gelu); output size = [H/4, W/4, 48]
        self.stem = Stem(out_dim=channels[0], act=act)
        # pos_embed shape: [1, 48, 56, 56]
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = 224 // 4 * 224 // 4  # 3136

        self.backbone = nn.ModuleList([])
        idx = 0
        for i in range(len(blocks)):  # blocks = [2, 2, 6, 2], i = 0, 1, 2, 3
            if i > 0:
                self.backbone.append(Downsample(channels[i - 1], channels[i]))
                HW = HW // 4  # 3136 -> 784 -> 196 -> 49
            for j in range(blocks[i]):
                self.backbone += [
                    Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation),
                                conv, act, norm, bias, stochastic, epsilon,
                                reduce_ratios[i], n=HW, drop_path=dpr[idx],
                                relative_pos=True),
                        FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx]))
                ]
                idx += 1
        self.backbone = Seq(*self.backbone)
        # The loop above unrolls to (pvig_ti configuration):
        # ----- stage 1: x2 -----
        #   Grapher(48, k=9, dilation=1, conv=mr, act=gelu, norm=batch, bias=True,
        #           stochastic=False, epsilon=0.2, r=4, n=3136, drop_path=0.0, relative_pos=True)
        #   FFN(48, 192, act=gelu, drop_path=0.0)
        # Downsample(48, 96), HW = 784
        # ----- stage 2: x2 -----
        #   Grapher(96, k=9, dilation=1, ..., r=2, n=784, ...)
        #   FFN(96, 384, act=gelu, drop_path=0.0)
        # Downsample(96, 240), HW = 196
        # ----- stage 3: x6 -----
        #   Grapher(240, k=9, dilation=2 (blocks 1-4) / 3 (blocks 5-6), ..., r=1, n=196, ...)
        #   FFN(240, 960, act=gelu, drop_path=0.0)
        # Downsample(240, 384), HW = 49
        # ----- stage 4: x2 -----
        #   Grapher(384, k=9, dilation=3, ..., r=1, n=49, ...)
        #   FFN(384, 1536, act=gelu, drop_path=0.0)

        self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(act),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, opt.n_classes, 1, bias=True))
        self.model_init()

    def model_init(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)
        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)

Stem
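The Grapher block is where the graph part happens: the H x W feature map is treated as a set of nodes, each node is linked to its k nearest neighbors in feature space (optionally dilated), and a max-relative ('mr') graph convolution aggregates neighbor information. Below is a simplified, self-contained sketch of that aggregation, not the repo's exact gcn_lib implementation (which works on [B, C, N, 1] tensors and adds dilation, stochastic sampling, and a reduce ratio for the distance computation).

import torch

# Simplified sketch of max-relative ('mr') graph convolution:
# for each node i, gather its k nearest neighbors j in feature space and
# aggregate max_j(x_j - x_i), then concatenate with x_i. In ViG this
# concatenation is followed by a 1x1 conv that mixes it back to C channels.
def max_relative_conv(x, k=9):
    # x: [B, N, C] node features (N = H*W flattened pixels)
    B, N, C = x.shape
    # pairwise euclidean distances between all nodes: [B, N, N]
    dist = torch.cdist(x, x, p=2)
    # indices of the k nearest neighbors per node (self included): [B, N, k]
    idx = dist.topk(k, dim=-1, largest=False).indices
    # gather neighbor features: [B, N, k, C]
    neighbors = torch.gather(
        x.unsqueeze(1).expand(B, N, N, C), 2,
        idx.unsqueeze(-1).expand(B, N, k, C))
    # max-relative aggregation: max over neighbors of (x_j - x_i): [B, N, C]
    rel = (neighbors - x.unsqueeze(2)).max(dim=2).values
    return torch.cat([x, rel], dim=-1)  # [B, N, 2C]

x = torch.randn(2, 196, 240)   # e.g. stage-3 features, 14x14 = 196 nodes
out = max_relative_conv(x, k=9)
print(out.shape)               # torch.Size([2, 196, 480])

In the real Grapher, this aggregation is wrapped with 1x1 conv projections before and after, a residual connection, and an optional learned relative positional bias added to the distance matrix.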
class Stem(nn.Module):
    """ Image to Visual Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf
    """
    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        # comments below assume the pvig_ti call: out_dim = channels[0] = 48
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, out_dim // 2, 3, stride=2, padding=1),   # in=3,  out=24; output [H/2, W/2, 24]
            nn.BatchNorm2d(out_dim // 2),
            act_layer(act),
            nn.Conv2d(out_dim // 2, out_dim, 3, stride=2, padding=1),  # in=24, out=48; output [H/4, W/4, 48]
            nn.BatchNorm2d(out_dim),
            act_layer(act),
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),       # in=48, out=48; output [H/4, W/4, 48]
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        x = self.convs(x)
        return x
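A quick shape check for the stem. This is a sketch: act_layer below is a stand-in for the repo's helper, which I assume maps an activation name to the corresponding nn module.

import torch
import torch.nn as nn

# Stand-in for the repo's act_layer helper (assumption: maps a name
# like 'relu'/'gelu' to the corresponding activation module).
def act_layer(act):
    return nn.GELU() if act == 'gelu' else nn.ReLU(inplace=True)

stem = Stem(in_dim=3, out_dim=48, act='gelu')  # out_dim = channels[0] for pvig_ti
x = torch.randn(1, 3, 224, 224)
y = stem(x)
print(y.shape)  # torch.Size([1, 48, 56, 56]) -> a [H/4, W/4] grid of 48-dim nodes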