基于Jittor框架实现LSGAN图像生成对抗网络
基于Jittor框架實現LSGAN圖像生成對抗網絡
生成對抗網絡(GAN, Generative Adversarial Networks )是一種深度學習模型,是近年來復雜分布上無監督學習最具前景的方法之一。GAN模型由生成器(Generator)和判別器(Discriminator)兩個部分組成。在訓練過程中,生成器的目標就是盡量生成真實的圖片去欺騙判別器。而判別器的目標就是盡量把生成器生成的圖片和真實的圖片分別開來。這樣,生成器和判別器構成了一個動態的“博弈過程”。許多相關的研究工作表明GAN能夠產生效果非常真實的生成效果。
使用Jittor框架實現了一種經典GAN模型LSGAN。LSGAN將GAN的目標函數由交叉熵損失替換成最小二乘損失,以此拒絕了標準GAN生成的圖片質量不高以及訓練過程不穩定這兩個缺陷。通過LSGAN的實現介紹了Jittor數據加載、模型定義、模型訓練的使用方法。
LSGAN論文:https://arxiv.org/abs/1611.04076
1.數據集準備
使用兩種數據集進行LSGAN的訓練,分別是Jittor自帶的數據集MNIST,和用戶構建的數據集CelebA。您可以通過以下鏈接下載CelebA數據集。
? CelebA 數據集: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
使用Jittor自帶的MNIST數據加載器方法如下。使用jittor.transform可以進行數據歸一化及數據增強,這里通過transform將圖片歸一化到[0,1]區間,并resize到標準大小112*112。。通過set_attrs函數可以修改數據集的相關參數,如batch_size、shuffle及transform等。
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(size=img_size),
transform.ImageNormalize(mean=[0.5], std=[0.5])
])
train_loader = MNIST (train=True, transform=transform)
.set_attrs(batch_size=batch_size, shuffle=True)
val_loader = MNIST (train=False, transform = transform)
.set_attrs(batch_size=1, shuffle=True)
使用用戶構建的CelebA數據集方法如下,通過通用數據加載器jittor.dataset.dataset.ImageFolder,輸入數據集路徑即可構建用戶數據集。
from jittor.dataset.dataset import ImageFolder
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(size=img_size),
transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dir = ‘./data/celebA_train’
train_loader = ImageFolder(train_dir)
.set_attrs(transform=transform, batch_size=batch_size, shuffle=True)
val_dir = ‘./data/celebA_eval’
val_loader = ImageFolder(val_dir)
.set_attrs(transform=transform, batch_size=1, shuffle=True)
2.模型定義
2.1.網絡結構
使用LSGAN進行圖像生成,下圖為LSGAN論文給出的網絡架構圖,其中(a)為生成器,(b)為判別器。生成器網絡輸入一個1024維的向量,生成分辨率為112112的圖像;判別器網絡輸入112112的圖像,輸出一個數字表示輸入圖像為真實圖像的可信程度。
受到VGG模型的啟發,生成器在與DCGAN的結構基礎上在前兩個反卷積層之后增加了兩個步長=1的反卷積層。除使用最小二乘損失函數外判別器的結構與DCGAN中的結構相同。與DCGAN相同,生成器和判別器分別使用了ReLU激活函數和LeakyReLU激活函數。
下面將介紹如何使用Jittor定義一個網絡模型。定義模型需要繼承基類jittor.Module,并實現__init__和execute函數。__init__函數在模型聲明時會被調用,用于進行模型內部op或其他模型的聲明及參數的初始化。該模型初始化時輸入參數dim表示訓練圖像的通道數,對于MNIST數據集dim為1,對于CelebA數據集dim為3。
execute函數在網絡前向傳播時會被調用,用于定義前向傳播的計算圖,通過autograd機制在訓練時Jittor會自動構建反向計算圖。
import jittor as jt
from jittor import nn, Module
class generator(Module):
def init(self, dim=3):
super(generator, self).init()
self.fc = nn.Linear(1024, 77256)
self.fc_bn = nn.BatchNorm(256)
self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
self.deconv1_bn = nn.BatchNorm(256)
self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)
self.deconv2_bn = nn.BatchNorm(256)
self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
self.deconv3_bn = nn.BatchNorm(256)
self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)
self.deconv4_bn = nn.BatchNorm(256)
self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)
self.deconv5_bn = nn.BatchNorm(128)
self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)
self.deconv6_bn = nn.BatchNorm(64)
self.deconv7 = nn.ConvTranspose(64 , dim, 3, 1, 1)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def execute(self, input):x = self.fc_bn(self.fc(input).reshape((input.shape[0], 256, 7, 7)))x = self.relu(self.deconv1_bn(self.deconv1(x)))x = self.relu(self.deconv2_bn(self.deconv2(x)))x = self.relu(self.deconv3_bn(self.deconv3(x)))x = self.relu(self.deconv4_bn(self.deconv4(x)))x = self.relu(self.deconv5_bn(self.deconv5(x)))x = self.relu(self.deconv6_bn(self.deconv6(x)))x = self.tanh(self.deconv7(x))return x
class discriminator(nn.Module):
def init(self, dim=3):
super(discriminator, self).init()
self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
self.conv2 = nn.Conv(64, 128, 5, 2, 2)
self.conv2_bn = nn.BatchNorm(128)
self.conv3 = nn.Conv(128, 256, 5, 2, 2)
self.conv3_bn = nn.BatchNorm(256)
self.conv4 = nn.Conv(256, 512, 5, 2, 2)
self.conv4_bn = nn.BatchNorm(512)
self.fc = nn.Linear(51277, 1)
self.leaky_relu = nn.Leaky_relu()
def execute(self, input):x = self.leaky_relu(self.conv1(input), 0.2)x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)x = x.reshape((x.shape[0], 512*7*7))x = self.fc(x)return x
2.2.損失函數
損失函數采用最小二乘損失函數,其中判別器損失函數如下。其中x為真實圖像,z為服從正態分布的1024維向量,a取值為1,b取值為0。
生成器損失函數如下。其中z為服從正態分布的1024維向量,c取值為1。
具體實現如下,x為生成器的輸出值,b表示該圖像是否希望被判別為真。
def ls_loss(x, b):
mini_batch = x.shape[0]
y_real_ = jt.ones((mini_batch,))
y_fake_ = jt.zeros((mini_batch,))
if b:
return (x-y_real_).sqr().mean()
else:
return (x-y_fake_).sqr().mean()
3.模型訓練
3.1.參數設定
參數設定如下。
通過use_cuda設置在GPU上進行訓練
jt.flags.use_cuda = 1
批大小
batch_size = 128
學習率
lr = 0.0002
訓練輪數
train_epoch = 50
訓練圖像標準大小
img_size = 112
Adam優化器參數
betas = (0.5,0.999)
數據集圖像通道數,MNIST為1,CelebA為3
dim = 1 if task==“MNIST” else 3
3.2.模型、優化器聲明
分別聲明生成器和判別器,并使用Adam作為優化器。
生成器
G = generator (dim)
判別器
D = discriminator (dim)
生成器優化器
G_optim = nn.Adam(G.parameters(), lr, betas=betas)
判別器優化器
D_optim = nn.Adam(D.parameters(), lr, betas=betas)
3.3.訓練
for epoch in range(train_epoch):
for batch_idx, (x_, target) in enumerate(train_loader):
mini_batch = x_.shape[0]
# 判別器訓練
D_result = D(sx)
D_real_loss = ls_loss(D_result, True)
z_ = init.gauss((mini_batch, 1024), ‘float’)
G_result = G(z_)
D_result_ = D(G_result)
D_fake_loss = ls_loss(D_result_, False)
D_train_loss = D_real_loss + D_fake_loss
D_train_loss.sync()
D_optim.step(D_train_loss)
# 生成器訓練z_ = init.gauss((mini_batch, 1024), 'float')G_result = G(z_)D_result = D(G_result)G_train_loss = ls_loss(D_result, True)G_train_loss.sync()G_optim.step(G_train_loss)if (batch_idx%100==0):print('D training loss =', D_train_loss.data.mean())print('G training loss =', G_train_loss.data.mean())
4.結果與測試
4.1.生成結果
分別使用MNIST和CelebA數據集進行了50個epoch的訓練。訓練完成后各隨機采樣了25張圖像,結果如下。
4.2.速度對比
使用Jittor與主流的深度學習框架PyTorch進行了訓練速度的對比,下表為PyTorch(是/否打開benchmark)及Jittor在兩種數據集上進行1次訓練迭帶的使用時間。得益于Jittor特有的元算子融合技術,其訓練速度比PyTorch快了40%~55%。
總結
以上是生活随笔為你收集整理的基于Jittor框架实现LSGAN图像生成对抗网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Angel图算法
- 下一篇: iOS视频硬编码技术