PyTorch Implementation of a GAN (Generative Adversarial Network), with Code
GAN
- Network Structure
- Understanding the GAN Objective
- Simple Linear GAN Code
- Convolutional GAN Code
- Ref
Network Structure
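The architecture is two networks wired head-to-tail: a generator $G$ that maps a noise vector to a sample, and a discriminator $D$ that maps a sample to a probability of being real. A minimal sketch of this data flow (the 100-dimensional noise and 784-dimensional flattened MNIST images match the linear GAN code later in this article; the single-layer networks here are purely illustrative):

```python
import torch
import torch.nn as nn

# Minimal sketch of the GAN data flow: G: noise -> sample, D: sample -> P(real).
G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())   # generator (one layer, for illustration)
D = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())  # discriminator

z = torch.randn(4, 100)   # batch of 4 noise vectors
fake = G(z)               # 4 generated "images", shape (4, 784), values in (-1, 1)
score = D(fake)           # 4 realness scores, shape (4, 1), values in (0, 1)
```

During training, $D$ sees both real samples and `fake`, while $G$ is updated only through the scores $D$ assigns to `fake`.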
Understanding the GAN Objective
$\min_G \max_D V(D,G) = E_{x\sim P_{data}(x)}[\log D(x)] + E_{z\sim P_z(z)}[\log(1-D(G(z)))]$
Understanding this objective is a necessary step toward understanding GANs, so let us walk through it. First, define the two players: $D$ is the discriminator and $G$ is the generator. The first task is to train the discriminator to recognize real data, which gives the first term of the objective.
$E_{x\sim P_{data}(x)}[\log D(x)]$
Here $E_{x\sim P_{data}(x)}$ denotes an expectation over $x$ drawn from $P_{data}$, where $x$ is a real sample and $P_{data}$ is the real-data distribution.
This first term measures how confidently the discriminator recognizes real data, and the discriminator wants to maximize it. Put simply: for $x$ drawn from $P_{data}$, the discriminator should output $D(x)\approx 1$.
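In code, "maximize $\log D(x)$" is usually implemented by minimizing the binary cross-entropy between $D(x)$ and the label 1, since $\mathrm{BCE}(D(x), 1) = -\log D(x)$. A quick numeric check in plain Python, mirroring what `nn.BCELoss` computes for a single sample:

```python
import math

def bce(pred, target):
    # binary cross-entropy for one sample, as nn.BCELoss computes it
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

d_x = 0.9  # discriminator's score for a real sample
# BCE against label 1 collapses to -log D(x):
assert abs(bce(d_x, 1.0) - (-math.log(d_x))) < 1e-12
# so minimizing BCE(D(x), 1) is exactly maximizing log D(x)
```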
Now consider the slightly more involved second term of the objective.
$E_{z\sim P_z(z)}[\log(1-D(G(z)))]$
Here $E_{z\sim P_z(z)}$ denotes an expectation over $z$ drawn from $P_z(z)$, where $z$ is the noise vector fed to the generator (so $G(z)$ is a generated sample) and $P_z(z)$ is the noise distribution.
For the discriminator $D$, when the input is generated data it computes $D(G(z))$, and its goal is to minimize that score: the discriminator wants $D(G(z))\approx 0$, which is the same as maximizing $\log(1-D(G(z)))$.
The generator's goal is the opposite: it wants its samples to receive a high score from the discriminator, i.e. $D(G(z))\approx 1$, which means minimizing $\log(1-D(G(z)))$. Note that the generator can only influence the second term of the objective; it has no effect on the first.
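A practical aside, and it is what the training code later in this article actually does: instead of minimizing $\log(1-D(G(z)))$, the generator is trained to minimize $-\log D(G(z))$ by labeling its fakes as real in the BCE loss. Both objectives push $D(G(z))$ toward 1, but the second gives much larger gradients early in training, when the discriminator easily rejects fakes and $\log(1-D(G(z)))$ saturates. A numeric sketch of the gradient magnitudes with respect to $p = D(G(z))$:

```python
p = 0.01  # early in training: D confidently rejects the fakes

# |d/dp log(1 - p)| = 1 / (1 - p): tiny when p is near 0 (saturating loss)
grad_saturating = 1.0 / (1.0 - p)
# |d/dp (-log p)| = 1 / p: large when p is near 0 (non-saturating loss)
grad_nonsaturating = 1.0 / p

print(grad_saturating, grad_nonsaturating)  # ~1.01 vs ~100
```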
We can now make sense of $V(D,G) = E_{x\sim P_{data}(x)}[\log D(x)] + E_{z\sim P_z(z)}[\log(1-D(G(z)))]$, but why does the full objective also carry $\min_G \max_D$ in front?
To understand $\min_G \max_D$, recall the GAN training procedure. At the start, the generator $G$'s parameters are frozen and only the discriminator $D$ is trained. The objective expresses the same idea: train with respect to $D$ first, maximizing both $\log D(x)$ and $\log(1-D(G(z)))$, and hence maximizing $V(D,G)$:
$D_G^\star = \arg\max_D V(D,G)$
Once the discriminator $D$ has been trained, its parameters are frozen in turn and the generator $G$ is trained. Since the discriminator has already had a round of training, the generator's goal becomes: with $D = D_G^\star$ fixed, minimize $\log(1-D(G(z)))$, and hence minimize $V(D,G)$:
$G^\star = \arg\min_G V(G, D_G^\star)$
These two alternating steps are exactly what $\min_G \max_D$ means: first maximize $V(D,G)$ from the discriminator's point of view, then minimize $V(D,G)$ from the generator's point of view.
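The two-step alternation can be written directly as a training loop. The sketch below is a hypothetical toy setup (not the MNIST code from this article): a 1-D GAN whose "real" data is drawn from $N(3,1)$. The structure of each iteration, a D-step followed by a G-step, is exactly the $\max_D$-then-$\min_G$ alternation described above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy 1-D GAN: real samples ~ N(3, 1); G maps 1-D noise to a scalar.
G = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))
D = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
bce = nn.BCELoss()
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3)
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)

for step in range(200):
    real = torch.randn(64, 1) + 3.0
    z = torch.randn(64, 1)

    # --- max_D step: G frozen (detach), push D(real) -> 1 and D(G(z)) -> 0 ---
    d_loss = (bce(D(real), torch.ones(64, 1)) +
              bce(D(G(z).detach()), torch.zeros(64, 1)))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # --- min_G step: D frozen (only opt_g steps), push D(G(z)) -> 1 ---
    g_loss = bce(D(G(z)), torch.ones(64, 1))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
```

With enough iterations the generator's output distribution drifts toward the real one; the MNIST scripts below follow this same skeleton, only with larger networks and image data.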
The derivation above uses logarithms throughout. Since $\log$ is monotonically increasing on its domain, taking logs does not change the relative ordering of values; it is used here to make the computation more convenient.
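The monotonicity claim is easy to check numerically: applying $\log$ to a set of probabilities preserves their ordering.

```python
import math

probs = [0.05, 0.2, 0.5, 0.9]     # already in increasing order
logs = [math.log(p) for p in probs]
assert logs == sorted(logs)       # log keeps the ordering intact
```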
Ref: 《深入淺出GAN生成對抗網絡》, 廖茂文
Simple Linear GAN Code
```python
import os

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.utils import save_image

if not os.path.exists('./img'):
    os.mkdir('./img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# MNIST dataset
mnist = datasets.MNIST(root='./data/', train=True,
                       transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size,
                                         shuffle=True)


# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x)
        return x


# Generator
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x


D = discriminator().to(device)
G = generator().to(device)
# Binary cross entropy loss and optimizers
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ================= train discriminator =================
        img = img.view(num_img, -1)
        real_img = img.to(device)
        real_label = torch.ones(num_img, 1, device=device)
        fake_label = torch.zeros(num_img, 1, device=device)

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img (detach so D's loss does not backprop into G)
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        fake_out = D(fake_img.detach())
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================= train generator =================
        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # backprop and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.data.mean(), fake_scores.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')
```

Convolutional GAN Code
```python
__author__ = 'ShelockLiao'

import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100  # noise dimension
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist = datasets.MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,
                        num_workers=4)


class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),   # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)         # batch, 32, 14, 14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)         # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''x: batch, channel=1, height, width'''
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # batch, 3136 = 1 x 56 x 56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),   # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),              # batch, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x


D = discriminator().to(device)               # discriminator model
G = generator(z_dimension, 3136).to(device)  # generator model

criterion = nn.BCELoss()  # binary cross entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# train
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ================= train discriminator =================
        real_img = img.to(device)
        real_label = torch.ones(num_img, 1, device=device)
        fake_label = torch.zeros(num_img, 1, device=device)

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img (detach so D's loss does not backprop into G)
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        fake_out = D(fake_img.detach())
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================= train generator =================
        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # backprop and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.data.mean(), fake_scores.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './dc_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generatorConv.pth')
torch.save(D.state_dict(), './discriminatorConv.pth')
```

Ref

《深入淺出GAN生成對抗網絡》, 廖茂文