Principles and Implementation of Generative Adversarial Networks
Introduction
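- Full name of GAN: generative adversarial network.
- Invented: 2014, by Ian Goodfellow and colleagues in Yoshua Bengio's lab.
- Purpose: to train a "forger" whose output is nearly indistinguishable from real data.
- How it works: how do you train a forger? With two networks, a generator network G and a discriminator network D, that compete with each other and improve in the process. The generator is the forger: what it produces is fed to the discriminator, which has to decide whether the input came from the real data or was forged. At the start of training the generator produces unrecognizable blobs and the discriminator is essentially guessing, but as training proceeds normally, both the generator's ability to produce images and the discriminator's ability to tell them apart improve. Although their losses keep fluctuating and are hard to push down, at times the two networks can become better at their tasks than a human. (In this example the generator loss and the discriminator loss behave as if they were mutually exclusive: when one is low, the other is high, and the two curves never seem to be low at the same time.)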
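For reference, the competition described above is usually written as the two-player minimax game from Goodfellow et al.'s 2014 paper, where D(x) is the discriminator's probability that x is real and z is the generator's random input drawn from a noise distribution p_z:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

D tries to maximize V and G tries to minimize it, which is why the two losses pull against each other. For the generator, the training code in this article does not minimize log(1 - D(G(z))) directly; it scores fake images against "real" labels instead, which is the common "non-saturating" variant.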
Walking through a GAN end to end with the MNIST handwritten-digit dataset
Set up the environment and download the MNIST dataset
%matplotlib inline
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms

num_workers = 0
batch_size = 64

# ToTensor converts each image to a float tensor with values in [0, 1]
transform = transforms.ToTensor()

train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers)
Visualize the data
dataiter = iter(train_loader)
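images, labels = next(dataiter)  # dataiter.next() no longer works in recent PyTorch versions
images = images.numpy()

# first image of the batch, with the channel dimension removed
img = np.squeeze(images[0])

fig = plt.figure(figsize=(3, 3))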
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
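Define the GAN model
A GAN consists of two networks: a discriminator network and a generator network. The network architecture is shown below:
In this example, both the generator and the discriminator are built from fully connected layers: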
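- The generator's input is a random vector of z_size = 100 values drawn uniformly from (-1, 1). Its output is a one-dimensional vector of 784 values, also in (-1, 1), because the last fully connected layer uses a tanh activation, which keeps its output in (-1, 1). Once the generator is trained, reshaping this 784-value vector to 28x28 gives a forged handwritten digit.
- The discriminator's input is a 28x28 image, either one forged by the generator or a real MNIST image, and its output is a single floating-point number (a logit, i.e. the value before any sigmoid). Once the discriminator is trained, an output greater than 0 means it considers the input a real MNIST image, and an output less than 0 means it considers the image forged.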
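A small NumPy sketch of the shapes and the decision rule just described. `sample_z`, `to_image` and `judged_real` are hypothetical helper names, and the "generator output" is faked with `np.tanh` of random numbers rather than produced by a trained network; only the shapes and value ranges matter here.

```python
import numpy as np

Z_SIZE = 100    # latent size, same as z_size in the training code
IMG_SIDE = 28   # MNIST images are 28x28 = 784 pixels

def sample_z(batch_size, z_size=Z_SIZE, rng=None):
    """Draw generator inputs uniformly from [-1, 1), like np.random.uniform in the training code."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(-1.0, 1.0, size=(batch_size, z_size))

def to_image(flat_output):
    """Reshape one 784-value generator output into a 28x28 image."""
    return np.asarray(flat_output).reshape(IMG_SIDE, IMG_SIDE)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def judged_real(logit):
    """Decision rule for a raw logit: logit > 0 is the same as sigmoid(logit) > 0.5."""
    return logit > 0

rng = np.random.default_rng(0)
z = sample_z(16, rng=rng)
# stand-in for G(z): tanh keeps every value inside (-1, 1)
fake_flat = np.tanh(rng.normal(size=(16, IMG_SIDE * IMG_SIDE)))
print(z.shape, fake_flat.shape, to_image(fake_flat[0]).shape)  # (16, 100) (16, 784) (28, 28)
for logit in (-2.0, 0.5):
    print(logit, sigmoid(logit), judged_real(logit))
```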
Discriminator network code
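We want the discriminator to output a value between 0 and 1 that says whether the input image is real or forged.
However, we will later use BCEWithLogitsLoss as the loss function for this GAN. It combines a sigmoid activation with BCELoss, so the discriminator network does not need a sigmoid on its output here.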
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()
        # fully connected layers that shrink toward a single output
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # flatten each 28x28 image into a 784-value vector
        x = x.view(-1, 28*28)
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # raw logit, no sigmoid: BCEWithLogitsLoss applies it
        out = self.fc4(x)
        return out
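The discriminator above returns a raw logit. To see why BCEWithLogitsLoss is preferred over a separate sigmoid plus BCELoss, here is a NumPy sketch (the function names are ours) comparing the naive formula with an algebraically equivalent, numerically stable form. PyTorch's documentation describes BCEWithLogitsLoss as more stable than Sigmoid followed by BCELoss for this kind of reason (it uses the log-sum-exp trick); this sketch is not PyTorch's implementation.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Numerically stable sigmoid + binary cross-entropy, averaged over the batch.

    Uses  -[y*log(s) + (1-y)*log(1-s)] = max(x, 0) - x*y + log(1 + exp(-|x|)),  s = sigmoid(x),
    which never evaluates exp() of a large positive number.
    """
    x = np.asarray(logits, dtype=float)
    y = np.asarray(targets, dtype=float)
    return np.mean(np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x))))

def naive_bce(logits, targets):
    """Separate sigmoid followed by BCE: mathematically the same, numerically fragile."""
    x = np.asarray(logits, dtype=float)
    y = np.asarray(targets, dtype=float)
    s = 1.0 / (1.0 + np.exp(-x))
    return np.mean(-(y * np.log(s) + (1 - y) * np.log(1 - s)))

logits = np.array([-3.0, -0.5, 0.0, 2.0])
targets = np.array([0.0, 1.0, 1.0, 0.0])
print(bce_with_logits(logits, targets), naive_bce(logits, targets))  # same value

# A very confident wrong logit: the stable form returns about 40
print(bce_with_logits([40.0], [0.0]))
```

For a logit of 40, sigmoid rounds to exactly 1.0 in float64, so the naive version would have to evaluate log(0); the stable form avoids that entirely.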
Generator network code
class Generator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()
        # fully connected layers that widen toward the 784 output pixels
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        self.fc4 = nn.Linear(hidden_dim*4, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # tanh keeps every output value in (-1, 1); torch.tanh replaces the deprecated F.tanh
        out = torch.tanh(self.fc4(x))
        return out
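As a sanity check on the two architectures, this pure-Python sketch counts their parameters from the layer widths, assuming the hyperparameters used in the training code (input_size = 784, z_size = 100, both hidden sizes 32, outputs 1 and 784). Dropout layers have no parameters. In the notebook, `sum(p.numel() for p in D.parameters())` (and the same expression for `G`) should give the same numbers.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

def mlp_params(widths):
    """Total parameters of a stack of Linear layers with the given widths."""
    return sum(linear_params(a, b) for a, b in zip(widths, widths[1:]))

hidden = 32
# Discriminator: 784 -> 4h -> 2h -> h -> 1 (a funnel down to the single logit)
d_widths = [784, hidden * 4, hidden * 2, hidden, 1]
# Generator: 100 -> h -> 2h -> 4h -> 784 (the mirror image, widening toward the pixels)
g_widths = [100, hidden, hidden * 2, hidden * 4, 784]

print(mlp_params(d_widths))  # 110849
print(mlp_params(g_widths))  # 114800
```

The two networks end up with a similar number of parameters, most of them in the layer that touches the 784 pixels.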
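[Core] How are the discriminator and the generator trained?
Training them is both simple and clever. The two networks are trained separately, but they have to be trained at the same time, because computing the discriminator's loss requires images produced by the generator, and computing the generator's loss requires the discriminator's predictions.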
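Training the discriminator:
- Take real images (in the code below, a batch of them), let the discriminator judge whether they are real or fake, and compute the loss d_real_loss.
- Feed the generator a random input vector z; it produces new 28x28 images. Feed these fake images to the discriminator, let it judge whether they are real or fake, and compute the loss d_fake_loss.
- The discriminator's total loss for this step is d_loss = d_real_loss + d_fake_loss.
- Update the discriminator's parameters once.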
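Training the generator:
- (Right after the discriminator update above) the generator produces a fresh batch of fake images, which are fed to the discriminator; the generator's loss for this step is computed from the discriminator's output, with "real" as the target label.
- Update the generator's parameters once.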
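The two alternating steps can be sketched in PyTorch on toy data. This is a minimal sketch, not the article's code: the tiny `nn.Sequential` models, the 2-D "data" and the `train_step` function are hypothetical stand-ins chosen so it runs in a moment. One deliberate difference is `fake.detach()` in the discriminator step, which stops the discriminator loss from back-propagating into the generator. The article's loop ends up with the same parameter updates without it, because it calls `g_optimizer.zero_grad()` before the generator step; it just computes some gradients it never uses.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_size, data_dim, batch = 8, 2, 32

# Hypothetical tiny models standing in for the article's Generator / Discriminator
G = nn.Sequential(nn.Linear(z_size, 16), nn.LeakyReLU(0.2), nn.Linear(16, data_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(data_dim, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))  # outputs a logit
criterion = nn.BCEWithLogitsLoss()
d_opt = torch.optim.Adam(D.parameters(), lr=2e-3)
g_opt = torch.optim.Adam(G.parameters(), lr=2e-3)

def train_step(real):
    n = real.size(0)
    ones, zeros = torch.ones(n, 1), torch.zeros(n, 1)

    # 1) Discriminator step: real -> label 1, fake -> label 0
    d_opt.zero_grad()
    fake = G(torch.rand(n, z_size) * 2 - 1)  # z ~ U(-1, 1)
    d_loss = criterion(D(real), ones) + criterion(D(fake.detach()), zeros)
    d_loss.backward()  # thanks to detach(), no gradients reach G here
    d_opt.step()

    # 2) Generator step: fresh fakes, scored by the just-updated D, against "real" labels
    g_opt.zero_grad()
    fake = G(torch.rand(n, z_size) * 2 - 1)
    g_loss = criterion(D(fake), ones)
    g_loss.backward()  # D's grads filled here are cleared by d_opt.zero_grad() next step
    g_opt.step()
    return d_loss.item(), g_loss.item()

real_batch = torch.rand(batch, data_dim) * 2 - 1  # stand-in "real" data in [-1, 1)
print(train_step(real_batch))
```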
Loss functions
# Calculate losses
# The only difference between the two functions below is that real_loss can apply label smoothing.
def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # label smoothing
    if smooth:
        # smooth, real labels = 0.9
        # Label smoothing trick: real images are too easy to learn, which would otherwise make the discriminator stop learning too early
        labels = torch.ones(batch_size)*0.9
    else:
        labels = torch.ones(batch_size)  # real labels = 1
    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss


def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)  # fake labels = 0
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss
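Why does a target of 0.9 help? Written as a function of the logit x, the per-example loss has derivative sigmoid(x) - y, so it is smallest where sigmoid(x) = y. With y = 1 that never happens for a finite logit, so the discriminator is always rewarded for becoming even more confident on real images; with y = 0.9 the optimum sits at x = ln(0.9 / 0.1) = ln 9 ≈ 2.197. The NumPy sketch below checks this on a grid of logits (the helper name is ours).

```python
import numpy as np

def bce_with_logits(x, y):
    """Stable per-example sigmoid + binary cross-entropy."""
    return np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))

logits = np.linspace(-10, 10, 20001)  # grid of possible discriminator outputs, step 0.001
for target in (1.0, 0.9):
    losses = bce_with_logits(logits, target)
    best = logits[np.argmin(losses)]
    print(f"target={target}: loss is smallest at logit {best:.3f}")
# target 1.0: the best logit is the edge of the grid (the loss keeps falling as the logit grows)
# target 0.9: the best logit is close to ln 9
```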
Training code
import pickle as pkl
import torch.optim as optim

# Discriminator hyperparams
# Size of input image to discriminator (28*28)
input_size = 784
# Size of discriminator output (real or fake)
d_output_size = 1
# Size of last hidden layer in the discriminator
d_hidden_size = 32

# Generator hyperparams
# Size of latent vector to give to generator
z_size = 100
# Size of generator output (generated image)
g_output_size = 784
# Size of first hidden layer in the generator
g_hidden_size = 32

# Instantiate the networks before creating the optimizers, which need their parameters
D = Discriminator(input_size, d_hidden_size, d_output_size)
G = Generator(z_size, g_hidden_size, g_output_size)

lr = 0.002
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)

num_epochs = 30

# keep track of loss and generated, "fake" samples
samples = []  # generator samples saved after each epoch
losses = []   # (d_loss, g_loss) saved after each epoch

# Get some fixed data for sampling. These are images that are held
# constant throughout training, and allow us to inspect the model's performance
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

# train the network
D.train()
G.train()
for epoch in range(num_epochs):

    for batch_i, (real_images, _) in enumerate(train_loader):

        batch_size = real_images.size(0)

        ## Important rescaling step ##
        real_images = real_images*2 - 1  # rescale input images from [0, 1] to [-1, 1]

        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        d_optimizer.zero_grad()

        # 1. Train with real images
        # Compute the discriminator losses on real images
        # smooth the real labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)

        # 2. Train with fake images
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)

        # add up loss and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()

        # 1. Train with fake images and flipped labels
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        # using flipped labels!
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake)  # use real loss to flip labels

        # perform backprop
        g_loss.backward()
        g_optimizer.step()

    ## AFTER EACH EPOCH ##
    print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
        epoch+1, num_epochs, d_loss.item(), g_loss.item()))
    # append discriminator loss and generator loss
    losses.append((d_loss.item(), g_loss.item()))

    # After each epoch, check what the generator produces and keep the result
    G.eval()  # eval mode for generating samples
    with torch.no_grad():
        samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train()  # back to train mode

# Save training generator samples: the generator's per-epoch outputs go into a pkl file
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
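In the generator step, the loop calls `real_loss(D_fake)`, i.e. it labels fake images as real ("flipped labels"). A short NumPy sketch of why, written in terms of the discriminator logit l for a fake image: early in training D easily spots fakes, so sigmoid(l) is close to 0. The gradient of the minimax form log(1 - sigmoid(l)) is then nearly zero, while the flipped-label loss -log(sigmoid(l)) still produces a gradient close to -1. The function names are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# l = discriminator logit for a fake image, so sigmoid(l) = D(G(z))
def minimax_grad(l):
    """d/dl of log(1 - sigmoid(l)), the generator term of the minimax objective."""
    return -sigmoid(l)

def flipped_grad(l):
    """d/dl of -log(sigmoid(l)), what real_loss computes for fakes (without smoothing)."""
    return -(1.0 - sigmoid(l))

for l in (-6.0, 0.0, 3.0):
    print(f"D(G(z))={sigmoid(l):.4f}  minimax grad={minimax_grad(l):+.4f}  "
          f"flipped-label grad={flipped_grad(l):+.4f}")
```

Both losses push D(G(z)) upward, but only the flipped-label version still does so strongly when the discriminator is winning.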
Loss curves over 30 epochs:
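The plot shows that the losses hardly decrease and fluctuate strongly. In fact, the generator loss and the discriminator loss move in opposite directions: when the discriminator is strong, the generator is weak, so one loss is high while the other is low. In that situation the generator receives larger gradient updates and before long overtakes the discriminator, so the two losses swap places; then the discriminator in turn speeds up, and the cycle repeats.
Training for 100 epochs gives much the same picture: judging by the losses, the two networks never converge (ignore the initial loss values):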
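The article does not show the code behind its loss plots. One way to draw such a plot from the `losses` list filled in the training loop (one `(d_loss, g_loss)` pair per epoch) is sketched below; `plot_losses` is a hypothetical helper. Note that each stored pair is the loss of the last batch of its epoch, which is part of why the curves look so noisy.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_losses(losses):
    """Plot the (d_loss, g_loss) pairs collected once per epoch."""
    arr = np.asarray(losses)  # shape (num_epochs, 2)
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(arr[:, 0], label="Discriminator", alpha=0.7)
    ax.plot(arr[:, 1], label="Generator", alpha=0.7)
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Training losses")
    ax.legend()
    return fig
```

After training, call `plot_losses(losses)` in the notebook.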
Visualize the generator's samples after each epoch
# Load samples from generator, taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

rows = 30  # one row per epoch (30 epochs were trained)
cols = 16  # generated images per row (only 16 samples were generated per epoch, so at most 16)
fig, axes = plt.subplots(figsize=(14, 28), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        img = img.detach()
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
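An alternative to creating one subplot per image is to tile the samples into a single array and draw it with one `imshow` call. `tile_samples` below is a hypothetical NumPy helper; it expects plain arrays, so convert a stored tensor first, e.g. `plt.imshow(tile_samples(samples[-1].detach().numpy(), ncols=4), cmap='Greys_r')`.

```python
import numpy as np

def tile_samples(samples, ncols, side=28, pad=2):
    """Arrange flat generator outputs (N x side*side, values in [-1, 1]) into one grid image.

    Returns a 2-D array with pixel values mapped to [0, 1] and `pad` pixels of
    background (0) between cells, ready for plt.imshow(grid, cmap="Greys_r").
    """
    imgs = np.asarray(samples, dtype=float).reshape(-1, side, side)
    n = imgs.shape[0]
    nrows = -(-n // ncols)  # ceiling division
    cell = side + pad
    grid = np.zeros((nrows * cell - pad, ncols * cell - pad))
    for i, img in enumerate(imgs):
        r, c = divmod(i, ncols)
        # map the tanh range [-1, 1] to [0, 1] while copying into the grid
        grid[r * cell:r * cell + side, c * cell:c * cell + side] = (img + 1) / 2
    return grid
```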
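Keep in mind that the generator's input is always random: a vector of z_size = 100 values drawn uniformly from (-1, 1). It looks like this:
As the figure below shows, after a single epoch the generator has already learned to put a cluster of "white pixels" in the middle of the image and to keep the borders "black".
After a few more epochs, it starts learning to forge some digits!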
Testing the generator
# helper function for viewing a list of passed in sample images
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7, 7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        img = img.detach()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')

# randomly generated, new latent vectors
sample_size = 16
rand_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
rand_z = torch.from_numpy(rand_z).float()

G.eval()  # eval mode
# generated samples
rand_images = G(rand_z)

# 0 indicates the first set of samples in the passed in list
# and we only have one batch of samples, here
view_samples(0, [rand_images])
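To keep or save generated digits as ordinary 8-bit images, the tanh output has to be mapped back from [-1, 1] to [0, 255], undoing the `real_images*2 - 1` rescaling applied to the training images. A minimal NumPy sketch (the function name is ours):

```python
import numpy as np

def to_uint8_images(gen_output, side=28):
    """Map generator outputs in [-1, 1] (the tanh range) to 8-bit grayscale images.

    This inverts the training-time rescaling x*2 - 1, which took ToTensor's [0, 1] range to [-1, 1].
    """
    x = np.asarray(gen_output, dtype=float).reshape(-1, side, side)
    x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    return np.clip(np.round(x * 255.0), 0, 255).astype(np.uint8)
```

For example, `to_uint8_images(rand_images.detach().numpy())` gives a (16, 28, 28) uint8 array.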