A GAN Implementation on MNIST [PyTorch]
Overview
This implementation is adapted from two pieces of code I wrote earlier (both posts already contain very detailed explanations, so feel free to refer to them):
- [Intro to GANs] pytorch-GANs task transfer - single target (digit generation)
- [Intro to GANs] A detailed walkthrough of a PyTorch GANs implementation [70+ lines of code]
It also draws on the DCGANs implementation I wrote earlier:
- (Deep Convolutional GANs) Reading and implementing the DCGANs paper in PyTorch
Selecting only a specific digit from MNIST follows the post below (a minimal sketch of the idea is given right after the link):
- Building an MNIST training set that contains only a specific digit
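For reference, here is a minimal sketch of that idea on its own. It assumes a recent torchvision release where `MNIST` exposes `data`/`targets` (the `train_data`/`train_labels` names used in the full script below are older aliases); the digit `0` and the `./mnist/` path simply match the values used later in this post.

```python
import torchvision

target_num = 0  # the digit we want to keep

mnist = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# boolean mask over the 60000 training labels, then filter images and labels in place
mask = mnist.targets == target_num
mnist.data = mnist.data[mask]
mnist.targets = mnist.targets[mask]

print(len(mnist))  # number of remaining samples, all labelled with target_num
```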
The earlier posts explain the code in great detail; this version is only a small improvement on top of them, so I won't repeat the full explanation here. The comments are still kept in the code.
Evolution of the generated images

(The script below saves an intermediate sample every 20 steps and stitches the frames into gan-mnist.gif, which shows how the generated digit evolves during training.)
Code
```python
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio

# directory for the intermediate PNG frames
PNGFILE = './png/'
if not os.path.exists(PNGFILE):
    os.mkdir(PNGFILE)
else:
    shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001           # learning rate for generator
LR_D = 0.0001           # learning rate for discriminator
N_IDEAS = 100           # think of this as number of ideas for generating an art work (Generator)
target_num = 0          # target number (digit) to generate
EPOCH = 10              # how many full passes over the training data
DOWNLOAD_MNIST = False  # automatically skipped if the data is already downloaded
ART_COMPONENTS = 28 * 28


# MNIST handwritten digits, restricted to a single target digit
class myMNIST(torchvision.datasets.MNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, targetNum=None):
        super(myMNIST, self).__init__(root, train=train, transform=transform,
                                      target_transform=target_transform, download=download)
        if targetNum is not None:
            # keep only the images of the target digit, trimmed to a multiple of BATCH_SIZE
            self.train_data = self.train_data[self.train_labels == targetNum]
            self.train_data = self.train_data[:int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]
            self.train_labels = self.train_labels[self.train_labels == targetNum][
                                :int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]

    def __len__(self):
        if self.train:
            return self.train_data.shape[0]
        else:
            return 10000


train_data = myMNIST(
    root='./mnist/',  # where to save / load the dataset
    train=True,       # this is training data
    transform=torchvision.transforms.ToTensor(),  # converts PIL.Image or numpy.ndarray into a
                                                  # torch.FloatTensor (C x H x W) in [0.0, 1.0]
    download=DOWNLOAD_MNIST,  # download if not present, otherwise reuse the local copy
    targetNum=target_num
)
print(len(train_data))
# print(train_data.shape)

# feed the training set BATCH_SIZE samples at a time; each image is 28*28
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True  # shuffle the samples
)

G = nn.Sequential(                   # Generator
    nn.Linear(N_IDEAS, 128),         # random ideas (could come from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),
)

D = nn.Sequential(                   # Discriminator
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                    # tell the probability that the art work is made by the artist
)

# loss & optimizer
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)

# defined but not used below; the losses are written out by hand instead
label_Real = torch.FloatTensor(BATCH_SIZE).data.fill_(1)
label_Fake = torch.FloatTensor(BATCH_SIZE).data.fill_(0)

filePath = []

for epoch in range(EPOCH):
    for step, (images, imagesLabel) in enumerate(train_loader):
        G_ideas = torch.randn((BATCH_SIZE, N_IDEAS))
        G_paintings = G(G_ideas)
        images = images.reshape(BATCH_SIZE, -1)
        prob_artist0 = D(images)       # D tries to increase this probability
        prob_artist1 = D(G_paintings)  # D tries to decrease this probability
        D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
        G_loss = torch.mean(torch.log(1. - prob_artist1))

        optimD.zero_grad()
        D_loss.backward(retain_graph=True)
        optimD.step()

        optimG.zero_grad()
        G_loss.backward(retain_graph=True)
        optimG.step()

        if step % 20 == 0:
            # save one generated sample as a PNG frame
            plt.cla()
            picture = torch.squeeze(G_paintings[0]).detach().numpy().reshape((28, 28))
            plt.imshow(picture, cmap=plt.cm.gray_r)
            plt.savefig(PNGFILE + '%d-%d.png' % (epoch, step))
            filePath.append(PNGFILE + '%d-%d.png' % (epoch, step))

# assemble the saved frames into an animated GIF
generated_images = []
for png_path in filePath:
    generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)
```
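A side note on the losses: `label_Real` and `label_Fake` are created but never used, which hints at the more common way of writing the same updates with `nn.BCELoss`. The sketch below is only an illustration of that variant, not part of the original script; it assumes `G`, `D`, `optimG`, `optimD`, `BATCH_SIZE` and `N_IDEAS` are defined as above, and `train_step` is a hypothetical helper name. It uses the non-saturating generator loss (maximizing log D(G(z)) rather than minimizing log(1 - D(G(z))) as above) and recomputes `D(fake)` for the generator step after the discriminator update, which also avoids the "backward after optimizer.step()" error that PyTorch 1.5+ may raise for the update order used in the script.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
label_real = torch.ones(BATCH_SIZE, 1)   # targets for "real" images
label_fake = torch.zeros(BATCH_SIZE, 1)  # targets for generated images

def train_step(images):
    """One GAN update written with BCELoss; hypothetical helper, not in the original script."""
    images = images.reshape(BATCH_SIZE, -1)

    # --- discriminator update: push D(real) toward 1 and D(fake) toward 0 ---
    noise = torch.randn(BATCH_SIZE, N_IDEAS)
    fake = G(noise)
    d_loss = criterion(D(images), label_real) + criterion(D(fake.detach()), label_fake)
    optimD.zero_grad()
    d_loss.backward()
    optimD.step()

    # --- generator update (non-saturating): push D(fake) toward 1 ---
    g_loss = criterion(D(fake), label_real)
    optimG.zero_grad()
    g_loss.backward()
    optimG.step()
    return d_loss.item(), g_loss.item()
```

Inside the training loop above, `train_step(images)` would replace the hand-written loss and update lines. The saturating log(1 - D(G(z))) loss used in this post tends to give the generator weak gradients while the discriminator is still strong, which is why the non-saturating form is the more common default.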