理解和创建GANs|使用PyTorch来做深度学习
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達作者:Venkatesh Tata
編譯:ronghuaiyang
生成對抗網(wǎng)絡(luò)的一篇實踐文章,使用PyTorch,用很簡單的代碼搭建了一個GANs,非常通俗易懂。
我們創(chuàng)建了一個生成對抗網(wǎng)絡(luò),可以生成顯示世界中沒有的鳥。
這些鳥都是通過GANs生成的。
在我們實際創(chuàng)建GAN之前,我們先看看GANs背后的思想。GANs是Ian Goodfellow發(fā)明的,他在斯坦福獲得了本科和碩士學位,在蒙特利爾大學獲得了博士學位。這是深度學習領(lǐng)域的一個新的大事。Yann LeCun說過:
"生成對抗網(wǎng)絡(luò)是近年來機器學習領(lǐng)域最有趣的想法"
什么是GANs?我們?yōu)槭裁匆獎?chuàng)造GANs?
神經(jīng)網(wǎng)絡(luò)很擅長分類和預測事情,但是AI的研究者想要讓神經(jīng)網(wǎng)絡(luò)更加像人類,通過創(chuàng)造東西而不僅僅是看見東西。 Ian Goodfellow成功的發(fā)明了這樣一類深度學習模型,可以用來創(chuàng)造東西。
GANs是怎么工作的?
GANs有兩個獨立的神經(jīng)網(wǎng)絡(luò)。一個叫做“G”,代表了生成器,另外一個叫做“D”,代表了判別器。生成器首先隨機的產(chǎn)生圖像,判別器通過觀察這些圖像告訴生成器這些圖片有多真實。
讓我們考慮一個生成器
在開始的時候,生成器用一個隨機噪聲信號作為輸入,產(chǎn)生一個隨機圖像作為輸出,通過判別器的幫助,開始產(chǎn)生越來越真實的圖像。
判別器
判別器是生成器的一個對手,它的輸入即有真實的圖像,同時也有生成器生成的圖像,判別器輸出這個圖像的真實程度。
到了某個點的時候,判別器無法判斷出這個圖像是否是真實圖像了,這時我們可以發(fā)現(xiàn)某個由生成器輸出的圖像是之前從沒有存在過的了。
GANs的應用
超分辨率
藝術(shù)輔助
元素抽取
開始寫代碼 !
注意:下面的代碼并不適合深度學習的新手,我希望你有一些python深度學習的經(jīng)驗。
開始我們先導入一些GAN需要的包。首先需要確保PyTorch已安裝。
#importing required librariesfrom __future__ import print_functionimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.optim as optimimport torch.utils.dataimport torchvision.datasets as dsetimport torchvision.transforms as transformsimport torchvision.utils as vutilsfrom torch.autograd import Variable設(shè)置一些超參數(shù),batch-size和圖像的尺寸:
# Setting hyperparametersbatchSize = 64 imageSize = 64第一行我們設(shè)置了batchsize為64,第二行設(shè)置了輸出圖像的尺寸為64x64。
然后我們創(chuàng)建一個圖像的轉(zhuǎn)換器的對象,如下:
# Creating the transformationstransform = transforms.Compose([transforms.Scale(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])上面的轉(zhuǎn)化器是將圖像作為判別器的輸入所必須的。
注意:如果需要獲取數(shù)據(jù)集,點擊這里:https://github.com/venkateshtata/GAN_Medium.git>,clone這個倉庫,然后替換 “dcgan.py” 文件為你需要寫入的python文件, “data” 文件夾存儲的是數(shù)據(jù)集。
現(xiàn)在我們加載數(shù)據(jù)集。這里我們使用的是 CIFAR-10的數(shù)據(jù)集。我們批量加載,確保你的python文件和你導入的數(shù)據(jù)集在同一個文件夾。
# Loading the datasetdataset = dset.CIFAR10(root = './data', download = True, transform = transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2)我們將數(shù)據(jù)集下載后放在./data目錄下,應用我們之前定義的轉(zhuǎn)化器。然后使用dataLoader 來獲取訓練圖像。其中‘num_workers’ 表示的是讀取數(shù)據(jù)用的線程的數(shù)量,其他的參數(shù)可以從字面意思理解。
由于這里我們需要處理兩個神經(jīng)網(wǎng)絡(luò),我們會定義一個全局的函數(shù)來初始化給定的神經(jīng)網(wǎng)絡(luò),只要將神經(jīng)網(wǎng)絡(luò)模型通過參數(shù)傳給這個函數(shù)即可。
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:m.weight.data.normal_(0.0, 0.02)elif classname.find('BatchNorm') != -1:m.weight.data.normal_(1.0, 0.02)m.bias.data.fill_(0)上面的函數(shù)獲取神經(jīng)網(wǎng)絡(luò)的模型作為參數(shù),初始化所有的參數(shù)。這個函數(shù)在訓練開始時在每個迭代都會調(diào)用。
第一步就是定義我們的生成器神經(jīng)網(wǎng)絡(luò)。我們創(chuàng)建一個生成器的類,里面包含了一系列的層。
分解上面的代碼:
我們創(chuàng)建了一個類‘G’,繼承了 ‘nn.module’,這個類里有構(gòu)建模型所需要的各種功能,只要將各種應用和連接放到神經(jīng)網(wǎng)絡(luò)里即可。
然后我們創(chuàng)建了一個模型,包含了一系列的模塊,如卷積,全連接等。
這里從圖中可以看大,生成器和判別器是相互倒著的。生成器的輸入時一個向量,所以這里我們使用了轉(zhuǎn)置卷積 ‘ConvTranspose2d’。
然后我們在batch的維度上對所有的特征進行了歸一化,然后使用ReLU進行了非線性變換。
我們重復上面的操作,輸入的節(jié)點從100變到了512,特征數(shù)從512變到了256,bias保持為False。
在最后的 ‘ConvTranspose2d’ 中,我們輸出了3個通道,因為輸出的是‘RGB’的圖像,使用了‘Tanh’作為激活函數(shù)。
現(xiàn)在我們創(chuàng)建一個forward函數(shù)來進行生成器信號的前向傳播。
def forward(self, input):output = self.main(input)return output上面的函數(shù)的輸入時長度為100的隨機向量。返回的是一個生成的圖像。隨機向量產(chǎn)生隨機圖像。
創(chuàng)建生成器:
netG = G() netG.apply(weights_init)這里我們創(chuàng)建了一個生成器,然后進行了參數(shù)初始化。
現(xiàn)在我們再定義一個判別器類:
class D(nn.Module):def __init__(self):super(D, self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1, bias = False),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(64, 128, 4, 2, 1, bias = False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(128, 256, 4, 2, 1, bias = False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(256, 512, 4, 2, 1, bias = False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(512, 1, 4, 1, 0, bias = False),nn.Sigmoid())判別器分解:
和G類似,判別器也是繼承了‘nn.module’,輸入是生成器生成的圖像,返回一個0~1之間的數(shù)字。
由于用生成器的輸出作為輸入,第一個操作時卷積,我們的激活函數(shù)使用了LeakyReLU。
可以看到,不同于生成器,我們這里使用了LeakyReLU,這個是經(jīng)驗得來的。
我們使用了‘BatchNorm2d’ 來進行特征歸一化。
最后,我們使用了sigmoid函數(shù),輸入0~1之間的概率。
為了進行前向傳播,我們定義一個forward函數(shù),使用生成器的輸出作為輸入:
def forward(self, input):output = self.main(input)return output.view(-1)最后一行,我們的輸出值在0~1之間,由于我們需要把向量鋪平,確保向量有相同的維度。
創(chuàng)建判別器 :
netD = D() netD.apply(weights_init)上面我們創(chuàng)建了判別器,初始化所有的參數(shù):
現(xiàn)在,我們開始訓練生成對抗網(wǎng)絡(luò)。開始之前,我們需要得到一個損失函數(shù),用來評價判別器的損失。我們使用 BCE Loss,非常適合對抗網(wǎng)絡(luò)。然后生成器和判別器我們都需要一個優(yōu)化器。
criterion = nn.BCELoss()optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))我們創(chuàng)建了一個評價函數(shù)用來度量預測和目標之間的差別。我們?yōu)榕袆e器和生成器各創(chuàng)建了一個優(yōu)化器。
我們使用了 ‘Adam’ 優(yōu)化器,這是個SGD的升級版。
我們訓練神經(jīng)網(wǎng)絡(luò)25個epochs:
for epoch in range(25):我們從數(shù)據(jù)集中循環(huán)讀取圖像 :
for i, data in enumerate(dataloader, 0):第一步需要更新判別器中的參數(shù),我們把判別器中所有的梯度清零。
netD.zero_grad()我們知道,判別器需要用真實和虛假的圖像同時訓練。這里我們先用一個真實圖像來訓練
real, _ = datainput = Variable(real)target = Variable(torch.ones(input.size()[0]))output = netD(input)errD_real = criterion(output, target)我們從數(shù)據(jù)集中獲取一個真實圖像訓練判別器,然后包裝成一個變量。然后前向傳播,得到預測值,然后計算loss。
現(xiàn)在,使用生成器輸出的虛假圖像訓練判別器:
noise = Variable(torch.randn(input.size()[0], 100, 1, 1))fake = netG(noise)target = Variable(torch.zeros(input.size()[0]))output = netD(fake.detach())errD_fake = criterion(output, target)這里,我們先讓一個隨機向量通過生成器,得到一個虛假的圖像。然后將這個虛假圖像通過判別器,得到預測,計算損失。
誤差反向傳播:
errD = errD_real + errD_fakeerrD.backward()optimizerD.step()這里我們計算判別器總的loss作為判別器的loss,更新判別器的時候,不更新生成器的權(quán)值。最后我們通過優(yōu)化器來判別器更新權(quán)值。
下面我們更新生成器的權(quán)值:
netG.zero_grad()target = Variable(torch.ones(input.size()[0]))output = netD(fake)errG = criterion(output, target)errG.backward()optimizerG.step()就像之前一樣,我們先將所有的梯度清零。然后將loss是通過計算生成器的梯度來反向傳播,然后通過生成器的優(yōu)化器來更新生成器的權(quán)值。
現(xiàn)在,我們最后的步驟就是在每100個steps時打印loss,存儲真實的圖像和生成的圖像,可以這么做:
print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, 25, i, len(dataloader), errD.data[0], errG.data[0]))if i % 100 == 0:vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)fake = netG(noise)vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize = True)完整代碼 :
from __future__ import print_functionimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.optim as optimimport torch.utils.dataimport torchvision.datasets as dsetimport torchvision.transforms as transformsimport torchvision.utils as vutilsfrom torch.autograd import VariablebatchSize = 64 imageSize = 64transform = transforms.Compose([transforms.Scale(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]) # We create a list of transformations (scaling, tensor conversion, normalization) to apply to the input images.dataset = dset.CIFAR10(root = './data', download = True, transform = transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:m.weight.data.normal_(0.0, 0.02)elif classname.find('BatchNorm') != -1:m.weight.data.normal_(1.0, 0.02)m.bias.data.fill_(0)class G(nn.Module):def __init__(self):super(G, self).__init__()self.main = nn.Sequential(nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),nn.BatchNorm2d(512),nn.ReLU(True),nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),nn.BatchNorm2d(64),nn.ReLU(True),nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),nn.Tanh())def forward(self, input):output = self.main(input)return outputnetG = G()netG.apply(weights_init)class D(nn.Module):def __init__(self):super(D, self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1, bias = False),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(64, 128, 4, 2, 1, bias = False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(128, 256, 4, 2, 1, bias = False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(256, 512, 4, 2, 1, bias = False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace = True),nn.Conv2d(512, 1, 4, 1, 0, bias = False),nn.Sigmoid())def forward(self, input):output = self.main(input)return output.view(-1)netD = D()netD.apply(weights_init)criterion = nn.BCELoss()optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))for epoch in range(25):for i, data in enumerate(dataloader, 0):netD.zero_grad()real, _ = datainput = Variable(real)target = Variable(torch.ones(input.size()[0]))output = netD(input)errD_real = criterion(output, target)noise = Variable(torch.randn(input.size()[0], 100, 1, 1))fake = netG(noise)target = Variable(torch.zeros(input.size()[0]))output = netD(fake.detach())errD_fake = criterion(output, target)errD = errD_real + errD_fakeerrD.backward()optimizerD.step()netG.zero_grad()target = Variable(torch.ones(input.size()[0]))output = netD(fake)errG = criterion(output, target)errG.backward()optimizerG.step()print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, 25, i, len(dataloader), errD.data[0], errG.data[0]))if i % 100 == 0:vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)fake = netG(noise)vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize = True)你可以從我的GitHub倉庫看到代碼:
https://github.com/venkateshtata/GAN_Medium
如果有好的建議,可以隨便fork或者拉代碼,謝謝!
好消息!
小白學視覺知識星球
開始面向外開放啦👇👇👇
下載1:OpenCV-Contrib擴展模塊中文版教程在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。交流群歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~總結(jié)
以上是生活随笔為你收集整理的理解和创建GANs|使用PyTorch来做深度学习的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: (附源码)计算机毕业设计SSM健身房管理
- 下一篇: 周末读fastclick.js源码有感