pytorch复现经典生成对抗式的超分辨率网络
生活随笔
收集整理的這篇文章主要介紹了
pytorch复现经典生成对抗式的超分辨率网络
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
論文原文:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
論文的中文翻譯:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
網絡結構如下圖所示:
上面和下面分別是生成網絡和判別網絡:
廢話不多說,直接看代碼。比較不喜歡一堆廢話的博客,懂得都懂,最有用的就是代碼!
代碼的實現參考pytorch torchvision中的網絡實現優點:模塊化、簡潔易讀、而且容易修改網絡寬度和深度(方便調整網絡架構做對比實驗,消融實驗)。
實現生成網絡:
# -*- coding: utf-8 -*-
# @Use      : SRGAN model definitions (generator part)
# @Time     : 2022/8/17 21:32
# @FileName : models.py
# @Software : PyCharm
# @reference: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

import torch
from torch import nn
from torch import Tensor

try:  # optional here; only needed by code outside this excerpt (e.g. a VGG loss)
    import torchvision  # noqa: F401
except ImportError:  # pragma: no cover
    torchvision = None


class GeneratorBasicBlock(nn.Module):
    """Residual block repeated B times inside the SRGAN generator.

    conv -> BN -> PReLU -> conv -> BN, with an identity skip connection
    added before returning (no activation after the sum, as in the paper).
    Spatial size and channel count are preserved.
    """

    def __init__(self, channel: int, kernel_size: int) -> None:
        super(GeneratorBasicBlock, self).__init__()
        self.channel = channel
        # 'same' padding derived from the kernel size; the original
        # hard-coded (1, 1), which is only correct for 3x3 kernels.
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(pad, pad))
        self.bn1 = nn.BatchNorm2d(num_features=channel)
        self.p_relu1 = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(pad, pad))
        self.bn2 = nn.BatchNorm2d(num_features=channel)

    def forward(self, x: Tensor) -> Tensor:
        """Run the residual block; output has the same shape as ``x``."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.p_relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        return out


class PixelShufflerBlock(nn.Module):
    """Upsampling tail block: conv -> PixelShuffle(x2) -> PReLU.

    ``out_channel`` must be 4x the desired post-shuffle channel count,
    because PixelShuffle trades a factor of 4 in channels for a factor
    of 2 in each spatial dimension.
    """

    def __init__(self, in_channel: int, out_channel: int) -> None:
        super(PixelShufflerBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixels_shuffle = nn.PixelShuffle(upscale_factor=2)
        self.prelu = nn.PReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` by 2x in each spatial dimension."""
        out = self.conv1(x)
        out = self.pixels_shuffle(out)
        out = self.prelu(out)
        return out


class Generator(nn.Module):
    """SRGAN generator: head conv, B residual blocks, long skip, 4x upsampling tail.

    :param config: object exposing ``config.G.large_kernel_size``,
        ``small_kernel_size``, ``n_channels``, ``n_blocks`` and
        ``base_block_type``.
    """

    def __init__(self, config) -> None:
        super(Generator, self).__init__()
        # Generator hyper-parameters (paper defaults in the comments).
        large_kernel_size = config.G.large_kernel_size  # 9
        small_kernel_size = config.G.small_kernel_size  # 3
        n_channels = config.G.n_channels                # 64
        n_blocks = config.G.n_blocks                    # 16
        # 'conv_residual' or 'depthwise_conv_residual'
        base_block_type = config.G.base_block_type

        # Choose the repeated base block. BUG FIX: the original used two
        # independent ``if`` statements and silently left ``repeat_block``
        # unset for an unknown type, crashing later with AttributeError.
        if base_block_type == 'depthwise_conv_residual':
            # NOTE(review): GeneratorDepthwiseBlock is not defined in this
            # excerpt; it must be provided elsewhere in the project.
            self.repeat_block = GeneratorDepthwiseBlock
        elif base_block_type == 'conv_residual':
            self.repeat_block = GeneratorBasicBlock
        else:
            raise ValueError(f'unknown base_block_type: {base_block_type!r}')

        # 'same' padding derived from the kernel sizes (was hard-coded
        # (4, 4)/(1, 1), correct only for the default 9/3 kernels).
        large_pad = large_kernel_size // 2
        small_pad = small_kernel_size // 2

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(large_pad, large_pad))
        self.prelu1 = nn.PReLU()
        # Attribute name (with its typo) kept for checkpoint compatibility.
        self.B_residul_block = self._make_layer(self.repeat_block, n_channels,
                                                n_blocks, small_kernel_size)
        self.conv2 = nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                               kernel_size=(small_kernel_size, small_kernel_size),
                               stride=(1, 1), padding=(small_pad, small_pad))
        self.bn1 = nn.BatchNorm2d(n_channels)
        # Two x2 PixelShuffle stages -> overall x4 upscaling.
        self.pixel_shuffle_block1 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.pixel_shuffle_block2 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.conv3 = nn.Conv2d(in_channels=n_channels, out_channels=3,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(large_pad, large_pad))

    def _make_layer(self, base_block, n_channels, n_block, kernel_size) -> nn.Sequential:
        """Stack ``n_block`` instances of ``base_block``.

        :param base_block: block class to instantiate
        :param n_channels: channel count inside each block
        :param n_block: number of repeated blocks
        :param kernel_size: conv kernel size passed to each block
        :return: the blocks wrapped in an ``nn.Sequential``
        """
        return nn.Sequential(*[base_block(n_channels, kernel_size)
                               for _ in range(n_block)])

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Forward pass: head -> residual trunk -> long skip -> x4 upsample -> tail."""
        out = self.conv1(x)
        out = self.prelu1(out)
        identity = out
        out = self.B_residul_block(out)
        out = self.conv2(out)
        out = self.bn1(out)
        out += identity
        out = self.pixel_shuffle_block1(out)
        out = self.pixel_shuffle_block2(out)
        out = self.conv3(out)
        return out

    def forward(self, x: Tensor) -> Tensor:
        """Map an LR image (N, 3, H, W) to an SR image (N, 3, 4H, 4W)."""
        return self._forward_impl(x)

# 判別網絡實現 (discriminator implementation follows)
class DiscriminatorBaseblock(nn.Module):
    """Discriminator conv block: conv -> BN -> LeakyReLU(0.2).

    Spatial downsampling is controlled by ``stride``; padding is fixed
    at (1, 1), matching the 3x3 kernels used by ``Discriminator``.
    """

    def __init__(self, inchannel, out_chanel, kernel_size, stride):
        super(DiscriminatorBaseblock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=out_chanel,
                               kernel_size=kernel_size, stride=stride, padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(out_chanel)
        self.act1 = nn.LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> BN -> LeakyReLU."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        return out


class Discriminator(nn.Module):
    """SRGAN discriminator: strided conv stack -> FC head -> sigmoid probability.

    :param config: object exposing ``config.D.kernel_size``, ``n_channels``,
        ``n_blocks`` and ``fc_size``.
    """

    def __init__(self, config):
        super(Discriminator, self).__init__()
        # BUG FIX: the original used chained assignment
        # (``kernel_size = config.D.kernel_size = 3``), which *overwrote*
        # the configured values with hard-coded defaults instead of reading
        # them. Read the config; paper defaults shown in the comments.
        kernel_size = config.D.kernel_size  # 3
        n_channels = config.D.n_channels    # 64
        n_blocks = config.D.n_blocks        # 8 (unused: the stack below is fixed)
        fc_size = config.D.fc_size          # 1024

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(1, 1))
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        # (kernel, out_channels, stride) for each downsampling block.
        conv_configs = [
            [3, 64, 2],
            [3, 128, 1],
            [3, 128, 2],
            [3, 256, 1],
            [3, 256, 2],
            [3, 512, 1],
            [3, 512, 2],
        ]
        self.base_blocks = self._make_layer(n_channels, DiscriminatorBaseblock,
                                            conv_configs)
        # 512 * 6 * 6 assumes 96x96 inputs (96 / 2**4 = 6) — TODO confirm
        # against the training pipeline's HR crop size.
        # BUG FIX: the original hard-coded 1024 here, ignoring fc_size.
        self.dense1 = nn.Linear(512 * 6 * 6, fc_size)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dense2 = nn.Linear(fc_size, 1)
        # Attribute name (with its typo) kept for checkpoint compatibility.
        self.sigmod1 = nn.Sigmoid()

    def _make_layer(self, in_channel, base_block, conv_configs: list) -> nn.Sequential:
        """Build the strided conv stack.

        :param in_channel: channels entering the first block
        :param base_block: block class, e.g. DiscriminatorBaseblock
        :param conv_configs: list of (kernel, out_channels, stride) triples
        :return: the blocks wrapped in an ``nn.Sequential``
        """
        layers = []
        for kernel, out_channel, stride in conv_configs:
            layers.append(base_block(in_channel, out_channel,
                                     (kernel, kernel), (stride, stride)))
            in_channel = out_channel  # chain the channel counts
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Forward pass: conv stack -> flatten -> FC head -> sigmoid."""
        out = self.conv1(x)
        out = self.leaky_relu1(out)
        out = self.base_blocks(out)
        out = torch.flatten(out, 1)
        out = self.dense1(out)
        out = self.leaky_relu2(out)
        out = self.dense2(out)
        out = self.sigmod1(out)
        return out

    def forward(self, x: Tensor) -> Tensor:
        """Return a per-image real/fake probability of shape (N, 1)."""
        return self._forward_impl(x)
總結
以上是生活随笔為你收集整理的pytorch复现经典生成对抗式的超分辨率网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Flask实现简单搜索功能
- 下一篇: 懐かしい