如何利用CycleGAN实现男女性别转换
介紹
CycleGAN網(wǎng)絡(luò)具有很強(qiáng)大的風(fēng)格遷移功能。能夠?qū)崿F(xiàn)非常深層次的風(fēng)格轉(zhuǎn)換。比如男性圖片女性化或者女性圖片男性化。
先上效果圖:
下面簡單談一談實現(xiàn)原理。
網(wǎng)絡(luò)結(jié)構(gòu)
網(wǎng)絡(luò)結(jié)構(gòu)如圖所示,通過兩個循環(huán)使用的生成器來進(jìn)行風(fēng)格遷移。由此實現(xiàn)了非常神奇的效果。
下面結(jié)合代碼來詳細(xì)解釋一下網(wǎng)絡(luò)結(jié)構(gòu)。訓(xùn)練生成對抗網(wǎng)絡(luò)的深度學(xué)習(xí)框架為Pytorch。
1. 殘差模塊定義
class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()# 殘差模塊不改變shapeconv_block = [ nn.ReflectionPad2d(1), # 構(gòu)建殘差模塊的時候使用映射填充的形式nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features), # 不使用BatchNorm而是使用InstanceNormnn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features) ]self.conv_block = nn.Sequential(*conv_block)def forward(self, x):return x + self.conv_block(x)殘差模塊的定義沒有太多需要說明的地方,就是有一點需要注意的是。我們在風(fēng)格遷移中,不再使用BatchNorm而是使用InstanceNorm。
BN是將每一個batch的每一個通道的每一組圖片求mean和var, IN是將單獨一個圖片的一個通道的數(shù)據(jù)求mean和var。 區(qū)別就是一個是對batch求,一個是對一個圖片求。風(fēng)格遷移中,為了保證風(fēng)格,通常都對每一個圖片單獨處理。 CycleGAN網(wǎng)絡(luò)中,每一個batch只有一張 圖片,所以使用InstanceNorm。
2. 定義生成器
class Generator(nn.Module):def __init__(self, input_nc, output_nc, n_residual_blocks=9):"""定義生成網(wǎng)絡(luò)參數(shù):input_nc --輸入通道數(shù)output_nc --輸出通道數(shù)n_residual_blocks --殘差模塊數(shù)量"""super(Generator, self).__init__()# 初始化卷積模塊# 因為使用ReflectionPad擴(kuò)充# 所以輸入是3*256*256# 輸出是64*256*256model = [ nn.ReflectionPad2d(3),nn.Conv2d(input_nc, 64, 7),nn.InstanceNorm2d(64),nn.ReLU(inplace=True) ]# 進(jìn)行下采樣# 第一個range:輸入是64*256*256,輸出是128*128*128# 第二個range:輸入是128*128*128,輸出是256*64*64in_features = 64out_features = in_features*2for _ in range(2):model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True) ]in_features = out_featuresout_features = in_features*2# 使用殘差模塊# 輸入輸出都是256*64*64for _ in range(n_residual_blocks): # 默認(rèn)添加9個殘差模塊model += [ResidualBlock(in_features)]# 進(jìn)行上采樣# 第一個range:輸入是256*64*64,輸出是128*128*128# 第二個range:輸入是128*128*128,輸出是64*256*256 out_features = in_features//2for _ in range(2):model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True) ]in_features = out_featuresout_features = in_features//2# 最后輸出層# 輸入是64*256*256# 輸出是3*256*256model += [ nn.ReflectionPad2d(3),nn.Conv2d(64, output_nc, 7),nn.Tanh() ]self.model = nn.Sequential(*model)def forward(self, x):return self.model(x)生成器的結(jié)構(gòu)就是最初那幅圖中的右側(cè)的樣子。進(jìn)行下采樣之后接一個殘差模塊,再之后進(jìn)行上采樣。生成器期望可以學(xué)到比較復(fù)雜的特征構(gòu)造方法,所以網(wǎng)絡(luò)結(jié)構(gòu)更深,更復(fù)雜。判別器結(jié)構(gòu)相對來說要簡單很多。
3. 判別器
class Discriminator(nn.Module):def __init__(self, input_nc):super(Discriminator, self).__init__()# 構(gòu)建卷積分類器# 輸入為3*256*256# 輸出為64*128*128model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),nn.LeakyReLU(0.2, inplace=True) ]# 輸入為64*128*128# 輸出為128*64*64model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True) ]# 輸入為128*64*64# 輸出為256*32*32model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True) ]# 輸入為256*32*32# 輸出為512*31*31model += [ nn.Conv2d(256, 512, 4, padding=1),nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True) ]# 全卷積分類層# 輸入為輸出為512*31*31# 輸出為1*30*30model += [nn.Conv2d(512, 1, 4, padding=1)]self.model = nn.Sequential(*model)def forward(self, x):x = self.model(x)# 使用平均池化的辦法輸出預(yù)測值# avg_pool2d(input,kernel_size),這里kernel_size為30return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)就是一個比較普通的分類網(wǎng)絡(luò)。通過步長為2來逐步縮小尺寸。可能值得注意的是,相比于傳統(tǒng)的分類神經(jīng)網(wǎng)絡(luò)。我們這里使用全局平均池化的方式進(jìn)行最終輸出預(yù)測。沒有使用全連接層,減小了網(wǎng)絡(luò)尺寸。
此外,我還做了一個exe交互程序。可以直接運行,實現(xiàn)圖片中頭像識別和對應(yīng)性別轉(zhuǎn)換。可以體驗一下生成對抗網(wǎng)絡(luò)的趣味。
對網(wǎng)絡(luò)感興趣,以及想要詳細(xì)了解原理是具體如何用代碼實現(xiàn),或者想用有趣數(shù)據(jù)集做出創(chuàng)意應(yīng)用的功能的話,可以參考這個視頻課程:點擊鏈接
總結(jié)
以上是生活随笔為你收集整理的如何利用CycleGAN实现男女性别转换的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何在Ubuntu18.04下安装CUD
- 下一篇: Faster RCNN网络简介