VGG11/13/16/19: PyTorch Implementation
Collected and organized by 生活随笔; shared here for reference.
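The four variants in the listing below differ only in which `cfg` list is passed to `make_layers`. Each integer is a 3×3 convolution with that many output channels (padding 1, stride 1, so height and width are unchanged), and each `'M'` is a 2×2 max pool with stride 2 (height and width halved). The small pure-Python sketch that follows needs no PyTorch. It counts the weight layers of each configuration (convolutions plus the three `nn.Linear` layers of the classifier) and tracks the feature-map size. The helpers `depth` and `feature_size` are hypothetical and exist only for this illustration.

```python
# Hypothetical helpers (not part of the original code): they only re-derive
# facts implied by the cfg table and the layer conventions used in make_layers.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
          512, 512, 512, 'M', 512, 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
          512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

NUM_FC_LAYERS = 3  # the classifier contains three nn.Linear layers


def depth(layer_cfg):
    """Weight layers = convolutions listed in cfg + the 3 linear layers."""
    num_conv = sum(1 for v in layer_cfg if v != 'M')
    return num_conv + NUM_FC_LAYERS


def feature_size(layer_cfg, input_hw):
    """Return (channels, height, width) of the feature map for a square input."""
    channels, hw = 3, input_hw
    for v in layer_cfg:
        if v == 'M':
            hw //= 2      # MaxPool2d(kernel_size=2, stride=2) halves H and W
        else:
            channels = v  # 3x3 conv with padding=1, stride=1 keeps H and W
    return channels, hw, hw


if __name__ == '__main__':
    for name in 'ABCD':
        print(name, depth(cfg[name]))
    print(feature_size(cfg['C'], 224))
    print(feature_size(cfg['C'], 32))
```

Configurations A to D give 11, 13, 16 and 19 weight layers, matching `vgg11_bn` through `vgg19_bn`. With a 224×224 input the last feature map is 512×7×7, which is where `nn.Linear(7 * 7 * 512, 4096)` comes from. With a 32×32 input it shrinks to 512×1×1, matching the commented-out `nn.Linear(512, 4096)` alternative.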
```python
import torch
import torch.nn as nn

cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG11
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG13
    'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # VGG16
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # VGG19
}


class VGG(nn.Module):
    def __init__(self, feature, num_class=10):
        super().__init__()
        self.feature = feature
        self.classifier = nn.Sequential(
            nn.Linear(7 * 7 * 512, 4096),  # a 224x224 input yields a 512x7x7 feature map
            # nn.Linear(512, 4096),        # a 32x32 input yields a 512x1x1 feature map
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.feature(x)
        output = output.view(output.size()[0], -1)  # flatten to (batch, C * H * W)
        output = self.classifier(output)
        return output


def make_layers(cfg, batch_norm=False):
    layers = []
    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue
        # stride defaults to 1; with padding=1 the spatial size is preserved
        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]
        if batch_norm:
            layers += [nn.BatchNorm2d(l)]
        layers += [nn.ReLU(inplace=True)]
        input_channel = l
    return nn.Sequential(*layers)  # *layers unpacks the list into positional arguments


def vgg11_bn():
    return VGG(make_layers(cfg['A'], batch_norm=True))


def vgg13_bn():
    return VGG(make_layers(cfg['B'], batch_norm=True))


def vgg16_bn():
    return VGG(make_layers(cfg['C'], batch_norm=True))


def vgg19_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))


if __name__ == '__main__':
    # l = [1, 2, 3]
    # print(*l)
    from torchsummary import summary  # third-party package: pip install torchsummary
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = vgg16_bn().to(device)
    summary(model, (3, 224, 224))
```

Output:
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,928
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
         MaxPool2d-7         [-1, 64, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]          73,856
       BatchNorm2d-9        [-1, 128, 112, 112]             256
             ReLU-10        [-1, 128, 112, 112]               0
           Conv2d-11        [-1, 128, 112, 112]         147,584
      BatchNorm2d-12        [-1, 128, 112, 112]             256
             ReLU-13        [-1, 128, 112, 112]               0
        MaxPool2d-14          [-1, 128, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         295,168
      BatchNorm2d-16          [-1, 256, 56, 56]             512
             ReLU-17          [-1, 256, 56, 56]               0
           Conv2d-18          [-1, 256, 56, 56]         590,080
      BatchNorm2d-19          [-1, 256, 56, 56]             512
             ReLU-20          [-1, 256, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]         590,080
      BatchNorm2d-22          [-1, 256, 56, 56]             512
             ReLU-23          [-1, 256, 56, 56]               0
        MaxPool2d-24          [-1, 256, 28, 28]               0
           Conv2d-25          [-1, 512, 28, 28]       1,180,160
      BatchNorm2d-26          [-1, 512, 28, 28]           1,024
             ReLU-27          [-1, 512, 28, 28]               0
           Conv2d-28          [-1, 512, 28, 28]       2,359,808
      BatchNorm2d-29          [-1, 512, 28, 28]           1,024
             ReLU-30          [-1, 512, 28, 28]               0
           Conv2d-31          [-1, 512, 28, 28]       2,359,808
      BatchNorm2d-32          [-1, 512, 28, 28]           1,024
             ReLU-33          [-1, 512, 28, 28]               0
        MaxPool2d-34          [-1, 512, 14, 14]               0
           Conv2d-35          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-36          [-1, 512, 14, 14]           1,024
             ReLU-37          [-1, 512, 14, 14]               0
           Conv2d-38          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-39          [-1, 512, 14, 14]           1,024
             ReLU-40          [-1, 512, 14, 14]               0
           Conv2d-41          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-42          [-1, 512, 14, 14]           1,024
             ReLU-43          [-1, 512, 14, 14]               0
        MaxPool2d-44            [-1, 512, 7, 7]               0
           Linear-45                 [-1, 4096]     102,764,544
             ReLU-46                 [-1, 4096]               0
          Dropout-47                 [-1, 4096]               0
           Linear-48                 [-1, 4096]      16,781,312
             ReLU-49                 [-1, 4096]               0
          Dropout-50                 [-1, 4096]               0
           Linear-51                   [-1, 10]          40,970
================================================================
Total params: 134,309,962
Trainable params: 134,309,962
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 321.94
Params size (MB): 512.35
Estimated Total Size (MB): 834.87
----------------------------------------------------------------

Process finished with exit code 0
```
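The `Total params` figure can be checked by hand. A 3×3 convolution from C_in to C_out channels has 9·C_in·C_out weights plus C_out biases. Each `BatchNorm2d` adds 2·C_out learnable parameters (scale and shift), which matches the 128 shown for 64 channels. Each `nn.Linear` layer has in·out weights plus out biases. The pure-Python sketch below applies these formulas to `cfg['C']` with `batch_norm=True` and `num_class=10`. It also converts the total to MB under the assumption of 4-byte float32 parameters, which is what the `Params size (MB)` line reports. The helper `count_params` is hypothetical and is not part of the original code.

```python
# Hypothetical helper: recomputes the parameter count of
# VGG(make_layers(cfg, batch_norm)) from layer shapes alone.
CFG_C = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']  # same list as cfg['C'] (VGG16)


def count_params(layer_cfg, batch_norm=True, num_class=10, feature_hw=7):
    total = 0
    in_ch = 3
    for v in layer_cfg:
        if v == 'M':
            continue                        # max pooling has no parameters
        total += 3 * 3 * in_ch * v + v      # Conv2d weights + biases
        if batch_norm:
            total += 2 * v                  # BatchNorm2d weight + bias
        in_ch = v
    # Classifier: Linear(7*7*512, 4096) -> Linear(4096, 4096) -> Linear(4096, num_class)
    dims = [feature_hw * feature_hw * in_ch, 4096, 4096, num_class]
    for d_in, d_out in zip(dims, dims[1:]):
        total += d_in * d_out + d_out       # Linear weights + biases
    return total


if __name__ == '__main__':
    n = count_params(CFG_C)
    print(f'{n:,}')                         # total trainable parameters
    print(round(n * 4 / 1024 ** 2, 2))      # float32 size in MB
```

This prints `134,309,962` and `512.35`, agreeing with the summary. The same arithmetic shows where the bulk of the parameters sits: the first linear layer alone accounts for 102,764,544 of them. Setting `batch_norm=False` removes 8,448 BatchNorm parameters, and `num_class=1000` raises the total to 138,365,992.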