Deep Learning Notes: How CycleGAN Works and How It Switches Image Styles, Compared with the Earlier Pseudo-DL Approach, with a PyTorch Implementation
Contents
- How CycleGAN works
- The usual approach to unsupervised conditional generation
- How CycleGAN handles unsupervised conditional generation:
- The more conventional idea:
- The CycleGAN idea:
- CycleGAN implementation:
- Discriminator architecture:
- Discriminator loss:
- Generator architecture:
- Generator architecture diagram:
- Residual Block:
- Autoencoder implementation:
- Generator loss:
- Overall structure:
- Optimizers:
- Training procedure:
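How CycleGAN works
The usual approach to unsupervised conditional generation
The figure is worth studying for reference; it is dense with information.
How CycleGAN handles unsupervised conditional generation:
The more conventional idea:
Take an image image_x from domain X and pass it through a neural network so that it comes out as a domain-Y version of image_x, and the problem is solved. A network needs a training objective, though. Here the objective has two parts: whether the output belongs to domain Y, and whether it still resembles image_x. These are the "real or not" and "match or not" criteria in the figure above. Both can be folded into the GAN framework: keep the output similar to the input while making it match the target domain more and more closely.
An earlier post described one approach: use VGG features to keep the image content similar, and use the inner products between feature maps to represent the image style. That is a pseudo-DL implementation: https://blog.csdn.net/weixin_40759186/article/details/87804316
The Discriminator in an ordinary GAN has a single objective, so how do we get both "real or not" and "match or not"? "Real or not" is enforced through a reconstruction error, and "match or not" through training a domain classifier.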
The CycleGAN idea:
"Real or not" is enforced through a reconstruction error, and "match or not" through training a domain classifier.
"Match or not" via a domain classifier: the fake data produced by a generator and the real data from the samples are used to train the domain classifier. This yields two Discriminators, D_x and D_y.
"Real or not" via reconstruction error: x_domainx -> a generator produces the fake data x_domainy -> the other generator produces the reconstruction x^_domainx. The reconstruction x^_domainx is then compared with the original x_domainx.
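The reconstruction path above can be written as a single loss term. The following is a sketch in the notation of the code later in this post (G_XtoY, G_YtoX, and lambda_weight, which the post sets to 10); the code averages the absolute difference over all pixels, so it matches the L1 norm below only up to a constant normalization factor.

```latex
\mathcal{L}_{\text{cyc}}
  = \lambda \, \mathbb{E}_{x \sim X}\!\left[ \lVert G_{Y \to X}(G_{X \to Y}(x)) - x \rVert_1 \right]
  + \lambda \, \mathbb{E}_{y \sim Y}\!\left[ \lVert G_{X \to Y}(G_{Y \to X}(y)) - y \rVert_1 \right]
```

The first term covers the cycle X -> Y -> X and the second the cycle Y -> X -> Y; both are added to the generators' adversarial losses.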
CycleGAN implementation:
Discriminator architecture:
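The Discriminator, the residual block, and the generator all call a `conv` helper that is not listed in this post. A minimal sketch of what it probably looks like, assuming it mirrors the `deconv` helper shown later (the defaults stride=2 and padding=1 are assumptions inferred from the shape comments in the Discriminator):

```python
import torch
import torch.nn as nn

def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Creates a convolutional layer, with optional batch normalization.

    Assumed counterpart of the post's deconv() helper; not taken from the post.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

# Shape check: a 128x128 RGB batch is halved to 64x64 with 64 channels.
layer = conv(3, 64, 4, batch_norm=False)
out = layer(torch.randn(2, 3, 128, 128))
print(out.shape)   # torch.Size([2, 64, 64, 64])
```

With kernel size 4, stride 2 and padding 1, every such layer halves the spatial size, which agrees with the 64, 32, 16, 8 progression in the Discriminator's comments.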
    import torch.nn as nn
    import torch.nn.functional as F

    class Discriminator(nn.Module):

        def __init__(self, conv_dim=64):
            super(Discriminator, self).__init__()

            # conv() is a helper assumed to wrap nn.Conv2d (stride=2, padding=1
            # by default) plus optional BatchNorm2d, mirroring deconv().
            # Define all convolutional layers
            # Should accept an RGB image as input and output a single value

            # Convolutional layers, increasing in depth
            # first layer has *no* batchnorm
            self.conv1 = conv(3, conv_dim, 4, batch_norm=False)  # x, y = 64, depth 64
            self.conv2 = conv(conv_dim, conv_dim*2, 4)  # (32, 32, 128)
            self.conv3 = conv(conv_dim*2, conv_dim*4, 4)  # (16, 16, 256)
            self.conv4 = conv(conv_dim*4, conv_dim*8, 4)  # (8, 8, 512)

            # Classification layer
            self.conv5 = conv(conv_dim*8, 1, 4, stride=1, batch_norm=False)

        def forward(self, x):
            # relu applied to all conv layers but last
            out = F.relu(self.conv1(x))
            out = F.relu(self.conv2(out))
            out = F.relu(self.conv3(out))
            out = F.relu(self.conv4(out))
            # last, classification layer
            out = self.conv5(out)
            return out

Discriminator loss:
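This is the usual GAN Discriminator objective: D is trained to keep both terms low, that is, to score real images near 1 (low real loss) and generated images near 0 (low fake loss).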
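To make the two terms concrete, here is a small worked example. The discriminator scores are made up by hand for illustration; only the two loss functions come from the post.

```python
import torch

def real_mse_loss(D_out):
    # squared distance of the discriminator output from the "real" target 1
    return torch.mean((D_out - 1) ** 2)

def fake_mse_loss(D_out):
    # squared distance of the discriminator output from the "fake" target 0
    return torch.mean(D_out ** 2)

# Hypothetical discriminator scores, chosen by hand.
scores_on_real = torch.tensor([0.9, 0.8])   # D is fairly sure these are real
scores_on_fake = torch.tensor([0.1, 0.2])   # D is fairly sure these are fake
d_loss_good = real_mse_loss(scores_on_real) + fake_mse_loss(scores_on_fake)

# A confused discriminator that outputs 0.5 everywhere pays a higher price.
confused = torch.tensor([0.5, 0.5])
d_loss_confused = real_mse_loss(confused) + fake_mse_loss(confused)

print(d_loss_good.item(), d_loss_confused.item())
```

The confident discriminator ends up with a total loss of about 0.05, the confused one with 0.5, so minimizing the sum pushes D toward separating the two kinds of input.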
    import torch

    def real_mse_loss(D_out):
        # how close is the produced output from being "real"?
        return torch.mean((D_out-1)**2)

    def fake_mse_loss(D_out):
        # how close is the produced output from being "false"?
        return torch.mean(D_out**2)

    # Excerpt from the training loop; D_X, G_YtoX, d_x_optimizer,
    # images_X and images_Y are defined there.

    ## First: D_X, real and fake loss components ##

    # Train with real images
    d_x_optimizer.zero_grad()

    # 1. Compute the discriminator losses on real images
    out_x = D_X(images_X)
    D_X_real_loss = real_mse_loss(out_x)

    # Train with fake images

    # 2. Generate fake images that look like domain X based on real images in domain Y
    fake_X = G_YtoX(images_Y)

    # 3. Compute the fake loss for D_X
    out_x = D_X(fake_X)
    D_X_fake_loss = fake_mse_loss(out_x)

    # 4. Compute the total loss and perform backprop
    d_x_loss = D_X_real_loss + D_X_fake_loss
    d_x_loss.backward()
    d_x_optimizer.step()

Generator architecture:
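Generator architecture diagram:
This is a typical autoencoder structure: an encoder that downsamples, a stack of residual blocks, and a decoder that upsamples back to the input size.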
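The encoder/decoder symmetry can be checked with plain arithmetic. This is a sketch, assuming a 128x128 input (suggested by the shape comments in the Discriminator) and that the `conv` helper uses stride 2 and padding 1 like `deconv`; the function names here are illustrative and not from the post.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    # output side length of a strided convolution
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    # output side length of a transposed convolution (no output_padding)
    return (size - 1) * stride - 2 * padding + kernel

def trace_generator(size=128):
    sizes = [size]
    for _ in range(3):   # conv1..conv3 (encoder)
        sizes.append(conv_out(sizes[-1]))
    # residual blocks use kernel 3, stride 1, padding 1: size unchanged
    sizes.append(conv_out(sizes[-1], kernel=3, stride=1, padding=1))
    for _ in range(3):   # deconv1..deconv3 (decoder)
        sizes.append(deconv_out(sizes[-1]))
    return sizes

print(trace_generator(128))   # [128, 64, 32, 16, 16, 32, 64, 128]
```

The output image has the same side length as the input, which is what lets one generator's output be fed straight into the other for the reconstruction step.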
Residual Block:
    import torch.nn as nn
    import torch.nn.functional as F

    # residual block class
    class ResidualBlock(nn.Module):
        """Defines a residual block.
           This adds an input x to a convolutional layer (applied to x) with the same size input and output.
           These blocks allow a model to learn an effective transformation from one domain to another.
        """
        def __init__(self, conv_dim):
            super(ResidualBlock, self).__init__()
            # conv_dim = number of inputs

            # define two convolutional layers + batch normalization that will act as our residual function, F(x)
            # layers should have the same shape input as output; I suggest a kernel_size of 3
            self.conv_layer1 = conv(in_channels=conv_dim, out_channels=conv_dim,
                                    kernel_size=3, stride=1, padding=1, batch_norm=True)
            self.conv_layer2 = conv(in_channels=conv_dim, out_channels=conv_dim,
                                    kernel_size=3, stride=1, padding=1, batch_norm=True)

        def forward(self, x):
            # apply a ReLU activation to the outputs of the first layer
            # return a summed output, x + resnet_block(x)
            out_1 = F.relu(self.conv_layer1(x))
            out_2 = x + self.conv_layer2(out_1)
            return out_2

Autoencoder implementation:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CycleGenerator(nn.Module):

        def __init__(self, conv_dim=64, n_res_blocks=6):
            super(CycleGenerator, self).__init__()

            # 1. Define the encoder part of the generator

            # initial convolutional layer given, below
            self.conv1 = conv(3, conv_dim, 4)
            self.conv2 = conv(conv_dim, conv_dim*2, 4)
            self.conv3 = conv(conv_dim*2, conv_dim*4, 4)

            # 2. Define the resnet part of the generator
            # Residual blocks
            res_layers = []
            for layer in range(n_res_blocks):
                res_layers.append(ResidualBlock(conv_dim*4))
            # use sequential to create these layers
            self.res_blocks = nn.Sequential(*res_layers)

            # 3. Define the decoder part of the generator
            # two transpose convolutional layers and a third that looks a lot like the initial conv layer
            self.deconv1 = deconv(conv_dim*4, conv_dim*2, 4)
            self.deconv2 = deconv(conv_dim*2, conv_dim, 4)
            # no batch norm on last layer
            self.deconv3 = deconv(conv_dim, 3, 4, batch_norm=False)

        def forward(self, x):
            """Given an image x, returns a transformed image."""
            # define feedforward behavior, applying activations as necessary
            out = F.relu(self.conv1(x))
            out = F.relu(self.conv2(out))
            out = F.relu(self.conv3(out))

            out = self.res_blocks(out)

            out = F.relu(self.deconv1(out))
            out = F.relu(self.deconv2(out))
            # tanh applied to last layer (torch.tanh replaces the deprecated F.tanh)
            out = torch.tanh(self.deconv3(out))

            return out

    # helper deconv function
    def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
        """Creates a transpose convolutional layer, with optional batch normalization."""
        layers = []
        # append transpose conv layer
        layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
        # optional batch norm layer
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        return nn.Sequential(*layers)

Generator loss:
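Note that the two generators (G1 and G2, i.e. G_XtoY and G_YtoX) are updated only after a complete cycle: all of their losses are summed and backpropagated in one step.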
    import torch

    def real_mse_loss(D_out):
        # how close is the produced output from being "real"?
        return torch.mean((D_out-1)**2)

    def fake_mse_loss(D_out):
        # how close is the produced output from being "false"?
        return torch.mean(D_out**2)

    def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
        # calculate reconstruction loss
        # as absolute value difference between the real and reconstructed images
        reconstr_loss = torch.mean(torch.abs(real_im - reconstructed_im))
        # return weighted loss
        return lambda_weight*reconstr_loss

    # =========================================
    #            TRAIN THE GENERATORS
    # =========================================
    # Excerpt from the training loop; the models, g_optimizer,
    # images_X and images_Y are defined there.

    ## First: generate fake X images and reconstructed Y images ##
    g_optimizer.zero_grad()

    # 1. Generate fake images that look like domain X based on real images in domain Y
    fake_X = G_YtoX(images_Y)

    # 2. Compute the generator loss based on domain X
    out_x = D_X(fake_X)
    g_YtoX_loss = real_mse_loss(out_x)

    # 3. Create a reconstructed y
    # 4. Compute the cycle consistency loss (the reconstruction loss)
    reconstructed_Y = G_XtoY(fake_X)
    reconstructed_y_loss = cycle_consistency_loss(images_Y, reconstructed_Y, lambda_weight=10)

    ## Second: generate fake Y images and reconstructed X images ##

    # 1. Generate fake images that look like domain Y based on real images in domain X
    fake_Y = G_XtoY(images_X)

    # 2. Compute the generator loss based on domain Y
    out_y = D_Y(fake_Y)
    g_XtoY_loss = real_mse_loss(out_y)

    # 3. Create a reconstructed x
    # 4. Compute the cycle consistency loss (the reconstruction loss)
    reconstructed_X = G_YtoX(fake_Y)
    reconstructed_x_loss = cycle_consistency_loss(images_X, reconstructed_X, lambda_weight=10)

    # 5. Add up all generator and reconstructed losses and perform backprop
    g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
    g_total_loss.backward()
    g_optimizer.step()

Overall structure:
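Two Discriminators and two Generators have to be trained:
One might ask whether a single Generator could do both jobs, X->G->Y and Y->G->X. That would be treating the Generator as if it could do anything.
Optimizers:
Notice that G_XtoY and G_YtoX share one optimizer, so they are updated together in a single cycle.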
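As a toy illustration of what sharing one optimizer means, the sketch below uses two small `nn.Linear` layers as stand-ins for the generators (hypothetical; the real ones are CycleGenerator instances) and checks that one optimizer step on a cycle-style loss changes the weights of both.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Two tiny stand-in "generators" (illustrative only).
G_a = nn.Linear(4, 4)
G_b = nn.Linear(4, 4)

# One optimizer over the concatenated parameter lists, as in the post.
params = list(G_a.parameters()) + list(G_b.parameters())
optimizer = optim.Adam(params, lr=0.0002, betas=(0.5, 0.999))

before_a = G_a.weight.detach().clone()
before_b = G_b.weight.detach().clone()

x = torch.randn(8, 4)
optimizer.zero_grad()
# A cycle-style loss: x -> G_a -> G_b should reproduce x.
loss = torch.mean(torch.abs(G_b(G_a(x)) - x))
loss.backward()
optimizer.step()

changed_a = not torch.equal(before_a, G_a.weight)
changed_b = not torch.equal(before_b, G_b.weight)
print(changed_a, changed_b)
```

Because the reconstruction passes through both networks, the gradient of a single loss reaches both parameter sets, and a single `step()` moves them together.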
    import torch.optim as optim

    # hyperparams for Adam optimizer
    lr = 0.0002
    beta1 = 0.5
    beta2 = 0.999  # default value

    g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())  # Get generator parameters

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(g_params, lr, [beta1, beta2])
    d_x_optimizer = optim.Adam(D_X.parameters(), lr, [beta1, beta2])
    d_y_optimizer = optim.Adam(D_Y.parameters(), lr, [beta1, beta2])

Training procedure:
Train D_X first, then D_Y, and finally G_XtoY and G_YtoX.
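The training loop calls a `scale` helper that is not listed in this post. A minimal sketch, assuming the images arrive as values in [0, 1] and should be mapped to [-1, 1] to match the tanh output of the generator (the `feature_range` parameter is an assumption):

```python
def scale(x, feature_range=(-1, 1)):
    """Map values from [0, 1] to feature_range (default [-1, 1]).

    Works on plain floats and, elementwise, on tensors.
    """
    low, high = feature_range
    return x * (high - low) + low

print(scale(0.0), scale(0.5), scale(1.0))   # -1.0 0.0 1.0
```

Keeping real images in the same range as the generated ones matters because the Discriminators see both.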
    import torch

    def training_loop(dataloader_X, dataloader_Y, test_dataloader_X, test_dataloader_Y,
                      n_epochs=1000):

        print_every = 10

        # keep track of losses over time
        losses = []

        test_iter_X = iter(test_dataloader_X)
        test_iter_Y = iter(test_dataloader_Y)

        # Get some fixed data from domains X and Y for sampling. These are images that are held
        # constant throughout training, that allow us to inspect the model's performance.
        # (next(it) replaces the old it.next(), which no longer exists on DataLoader iterators)
        fixed_X = next(test_iter_X)[0]
        fixed_Y = next(test_iter_Y)[0]
        fixed_X = scale(fixed_X)  # make sure to scale to a range -1 to 1
        fixed_Y = scale(fixed_Y)

        # batches per epoch
        iter_X = iter(dataloader_X)
        iter_Y = iter(dataloader_Y)
        batches_per_epoch = min(len(dataloader_X), len(dataloader_Y))

        for epoch in range(1, n_epochs+1):

            # Reset iterators for each epoch
            if epoch % batches_per_epoch == 0:
                iter_X = iter(dataloader_X)
                iter_Y = iter(dataloader_Y)

            images_X, _ = next(iter_X)
            images_X = scale(images_X)  # make sure to scale to a range -1 to 1

            images_Y, _ = next(iter_Y)
            images_Y = scale(images_Y)

            # move images to GPU if available (otherwise stay on CPU)
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            images_X = images_X.to(device)
            images_Y = images_Y.to(device)

            # ============================================
            #            TRAIN THE DISCRIMINATORS
            # ============================================

            ## First: D_X, real and fake loss components ##

            # Train with real images
            d_x_optimizer.zero_grad()

            # 1. Compute the discriminator losses on real images
            out_x = D_X(images_X)
            D_X_real_loss = real_mse_loss(out_x)

            # Train with fake images

            # 2. Generate fake images that look like domain X based on real images in domain Y
            fake_X = G_YtoX(images_Y)

            # 3. Compute the fake loss for D_X
            out_x = D_X(fake_X)
            D_X_fake_loss = fake_mse_loss(out_x)

            # 4. Compute the total loss and perform backprop
            d_x_loss = D_X_real_loss + D_X_fake_loss
            d_x_loss.backward()
            d_x_optimizer.step()

            ## Second: D_Y, real and fake loss components ##

            # Train with real images
            d_y_optimizer.zero_grad()

            # 1. Compute the discriminator losses on real images
            out_y = D_Y(images_Y)
            D_Y_real_loss = real_mse_loss(out_y)

            # Train with fake images

            # 2. Generate fake images that look like domain Y based on real images in domain X
            fake_Y = G_XtoY(images_X)

            # 3. Compute the fake loss for D_Y
            out_y = D_Y(fake_Y)
            D_Y_fake_loss = fake_mse_loss(out_y)

            # 4. Compute the total loss and perform backprop
            d_y_loss = D_Y_real_loss + D_Y_fake_loss
            d_y_loss.backward()
            d_y_optimizer.step()

            # =========================================
            #            TRAIN THE GENERATORS
            # =========================================

            ## First: generate fake X images and reconstructed Y images ##
            g_optimizer.zero_grad()

            # 1. Generate fake images that look like domain X based on real images in domain Y
            fake_X = G_YtoX(images_Y)

            # 2. Compute the generator loss based on domain X
            out_x = D_X(fake_X)
            g_YtoX_loss = real_mse_loss(out_x)

            # 3. Create a reconstructed y
            # 4. Compute the cycle consistency loss (the reconstruction loss)
            reconstructed_Y = G_XtoY(fake_X)
            reconstructed_y_loss = cycle_consistency_loss(images_Y, reconstructed_Y, lambda_weight=10)

            ## Second: generate fake Y images and reconstructed X images ##

            # 1. Generate fake images that look like domain Y based on real images in domain X
            fake_Y = G_XtoY(images_X)

            # 2. Compute the generator loss based on domain Y
            out_y = D_Y(fake_Y)
            g_XtoY_loss = real_mse_loss(out_y)

            # 3. Create a reconstructed x
            # 4. Compute the cycle consistency loss (the reconstruction loss)
            reconstructed_X = G_YtoX(fake_Y)
            reconstructed_x_loss = cycle_consistency_loss(images_X, reconstructed_X, lambda_weight=10)

            # 5. Add up all generator and reconstructed losses and perform backprop
            g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
            g_total_loss.backward()
            g_optimizer.step()