YoloV3网络模型搭建
生活随笔
收集整理的這篇文章主要介紹了
YoloV3网络模型搭建
小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
結構圖是複製別人的。
（原文此處為網絡結構圖，轉載後圖片未能顯示。）
import torch
import torch.nn as nn
from collections import OrderedDict


class CBL(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

    Args:
        channel_in: number of input channels.
        channel_out: number of output channels.
        ks: convolution kernel size.
        p: zero padding on each side (default 1; pass 0 with 1x1 kernels).
        strides: convolution stride, forwarded to ``nn.Conv2d``.
    """

    def __init__(self, channel_in, channel_out, ks, p=1, strides=(1, 1)):
        super(CBL, self).__init__()
        self.block = nn.Sequential(
            # ``strides`` is forwarded so a strided CBL really downsamples.
            nn.Conv2d(channel_in, channel_out, ks, stride=strides, padding=p),
            nn.BatchNorm2d(channel_out),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        return self.block(x)


class ResidualBlock(nn.Module):
    """Residual unit: 1x1 conv (reduce) -> 3x3 conv (expand) -> 1x1 conv, plus identity.

    Args:
        inp: number of input channels. Must equal ``planes[1]`` because the
            block output is added element-wise to its input.
        planes: ``[hidden_channels, output_channels]``. The first 1x1 conv maps
            ``inp`` to ``planes[0]``, the 3x3 conv maps ``planes[0]`` up to
            ``planes[1]``, and the last 1x1 conv keeps ``planes[1]`` channels.

    Raises:
        ValueError: if ``inp != planes[1]`` (the shortcut addition would fail
            or silently broadcast).
    """

    def __init__(self, inp, planes):
        super(ResidualBlock, self).__init__()
        if inp != planes[1]:
            raise ValueError(
                "ResidualBlock needs inp == planes[1], got inp={} and planes={}".format(
                    inp, planes
                )
            )
        self.conv1 = nn.Sequential(
            nn.Conv2d(inp, planes[0], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(planes[0]),
            nn.LeakyReLU(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(planes[0], planes[1], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(planes[1]),
            nn.LeakyReLU(0.1),
        )
        # NOTE(review): the Darknet-53 residual unit has only the 1x1 + 3x3
        # convs; this extra width-preserving 1x1 conv is specific to this code.
        self.conv3 = nn.Sequential(
            nn.Conv2d(planes[1], planes[1], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(planes[1]),
            nn.LeakyReLU(0.1),
        )

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        return x + inputs


class DarkNet(nn.Module):
    """Darknet-style backbone returning three feature maps.

    Args:
        blocks: five integers, the number of ``ResidualBlock`` units per stage.
            Every stage begins with a stride-2 3x3 conv, so the returned maps
            are downsampled by 8, 16 and 32 relative to the input.
    """

    def __init__(self, blocks):
        super(DarkNet, self).__init__()
        # Running channel count; _make_residual_layer advances it stage by stage.
        self.inp = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=self.inp, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.inp),
            nn.LeakyReLU(0.1),
        )
        # Stages 1-3 are grouped: only the output of the third one is exposed.
        self.block1 = nn.Sequential(
            self._make_residual_layer(planes=[32, 64], block=blocks[0]),
            self._make_residual_layer(planes=[64, 128], block=blocks[1]),
            self._make_residual_layer(planes=[128, 256], block=blocks[2]),
        )
        self.block2 = self._make_residual_layer(planes=[256, 512], block=blocks[3])
        self.block3 = self._make_residual_layer(planes=[512, 1024], block=blocks[4])

    def forward(self, inputs):
        """Return ``(feat1, feat2, feat3)`` with 256, 512 and 1024 channels."""
        x = self.conv1(inputs)
        feat1 = self.block1(x)  # stride 8
        feat2 = self.block2(feat1)  # stride 16
        feat3 = self.block3(feat2)  # stride 32
        return feat1, feat2, feat3

    def _make_residual_layer(self, planes, block):
        """Build one stage: a stride-2 downsampling conv, then ``block`` residual units.

        Args:
            planes: ``[hidden_channels, output_channels]`` for this stage.
            block: number of ``ResidualBlock`` units stacked after the downsampling conv.

        Side effect: sets ``self.inp`` to ``planes[1]`` so the next stage
        starts from this stage's output width.
        """
        layers = [
            ("ds_conv", nn.Conv2d(self.inp, planes[1], kernel_size=3, stride=2,
                                  padding=1, bias=False)),
            ("ds_bn", nn.BatchNorm2d(planes[1])),
            ("ds_relu", nn.LeakyReLU(0.1)),
        ]
        self.inp = planes[1]
        # The downsampling conv is added unconditionally, so every one of the
        # ``block`` iterations yields a residual unit (no off-by-one).
        for i in range(block):
            layers.append(("residual_{}".format(i), ResidualBlock(self.inp, planes)))
        return nn.Sequential(OrderedDict(layers))


class YoloBody(nn.Module):
    """YOLOv3-style detector: DarkNet backbone plus a three-scale neck and heads.

    Args:
        num_classes: number of object classes. Each head outputs
            ``3 * (num_classes + 5)`` channels per grid cell (presumably
            3 anchors x [box(4), objectness, class scores]; decoding is not
            done in this module).
    """

    def __init__(self, num_classes):
        super(YoloBody, self).__init__()
        self.darknet = DarkNet([1, 2, 8, 8, 4])
        self.num_classes = num_classes
        # Fixed at construction so registered heads and forward() always agree.
        self._out_channels = (num_classes + 5) * 3

        # Every layer is created here so it is a registered sub-module: it is
        # trained, saved in state_dict() and moved by .to(device).
        self.five_convs = nn.ModuleDict({
            self._layer_key(c_in, c_out): self._build_five_conv(c_in, c_out)
            for c_in, c_out in ((1024, 512), (768, 256), (384, 128))
        })
        self.heads = nn.ModuleDict({
            self._layer_key(c_in, self._out_channels): self._build_head(c_in, self._out_channels)
            for c_in in (512, 256, 128)
        })
        # nn.Upsample(mode="bilinear", align_corners=True) is the documented
        # replacement for the deprecated nn.UpsamplingBilinear2d.
        self.up1 = nn.Sequential(
            CBL(512, 256, 1, 0),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )
        self.up2 = nn.Sequential(
            CBL(256, 128, 1, 0),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, inputs):
        """Return ``(big_output, middle_output, small_output)``.

        The maps are at strides 32, 16 and 8, each with
        ``3 * (num_classes + 5)`` channels.
        """
        # 1. Backbone.
        feat1, feat2, feat3 = self.darknet(inputs)

        # 2. Feature-enhancement neck and detection heads, coarse to fine.
        big_map, x = self.make_five_conv(feat3, 1024, 512)
        big_output = self.output_conv(big_map, 512, self._out_channels)

        x = self.up1(x)
        x = torch.cat([feat2, x], dim=1)  # 512 + 256 = 768 channels
        middle_map, x = self.make_five_conv(x, 768, 256)
        middle_output = self.output_conv(middle_map, 256, self._out_channels)

        x = self.up2(x)
        x = torch.cat([feat1, x], dim=1)  # 256 + 128 = 384 channels
        small_map, _ = self.make_five_conv(x, 384, 128)
        small_output = self.output_conv(small_map, 128, self._out_channels)

        return big_output, middle_output, small_output

    def make_five_conv(self, x, channel_in, channel_out):
        """Apply the registered five-conv stack for ``(channel_in, channel_out)``.

        Returns ``(x, x)``: the same tensor feeds both the detection head and
        the upsampling branch.

        Raises:
            KeyError: if no stack was registered for this channel pair.
        """
        x = self.five_convs[self._layer_key(channel_in, channel_out)](x)
        return x, x

    def output_conv(self, x, channel_in, channel_out):
        """Apply the registered detection head for ``(channel_in, channel_out)``.

        Raises:
            KeyError: if no head was registered for this channel pair.
        """
        return self.heads[self._layer_key(channel_in, channel_out)](x)

    @staticmethod
    def _layer_key(channel_in, channel_out):
        """ModuleDict key for a layer mapping ``channel_in`` -> ``channel_out``."""
        return "{}_{}".format(channel_in, channel_out)

    @staticmethod
    def _build_five_conv(channel_in, channel_out):
        """Build the 1x1 / 3x3 / 1x1 / 3x3 / 1x1 CBL stack used by make_five_conv."""
        # NOTE(review): reference YOLOv3 widens the two 3x3 convs to
        # 2 * channel_out; this implementation keeps every layer at channel_out.
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
        )

    @staticmethod
    def _build_head(channel_in, channel_out):
        """Build a detection head: 3x3 CBL, then a plain 1x1 conv producing raw predictions."""
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=3),
            nn.Conv2d(channel_out, channel_out, kernel_size=1, padding=0),
        )


if __name__ == "__main__":
    inputs = torch.zeros(size=(1, 3, 416, 416))
    print(inputs.shape)
    model = YoloBody(num_classes=20)
    model.eval()
    # Call the module (not .forward) so hooks run; no autograd needed for a shape check.
    with torch.no_grad():
        big_output, middle_output, small_output = model(inputs)
    print(big_output.shape, middle_output.shape, small_output.shape)
以上是生活随笔為你收集整理的YoloV3网络模型搭建的全部內容，希望文章能夠幫你解決所遇到的問題。
- 上一篇: python装饰器简单理解的小demo
- 下一篇: 2022Go安装goimports第三方