plt生成固定的colormap_白话生成对抗网络GAN及代码实现
本文主要是個(gè)簡單的筆記,參考資料來自下面三部分
神經(jīng)網(wǎng)絡(luò)一覽
各種神經(jīng)網(wǎng)絡(luò)(全連接前向網(wǎng)絡(luò)、卷積神經(jīng)網(wǎng)絡(luò)、循環(huán)神經(jīng)網(wǎng)絡(luò))的區(qū)別在于具有不同的輸入/輸出形式,比如可以是向量、矩陣或者是向量序列等。
GAN的基本思想
GAN由生成器和判別器組成:
生成器的本質(zhì)也是一個(gè)神經(jīng)網(wǎng)絡(luò),或者說是一個(gè)函數(shù)
如果給定一個(gè)向量可以生成一張漫畫圖片,向量的每一個(gè)維度具有不同含義
判別器的本質(zhì)也是一個(gè)神經(jīng)網(wǎng)絡(luò)
如果給定一張圖片,判別器就會(huì)告訴你這是不是真實(shí)圖片
所以GAN的訓(xùn)練本質(zhì)就是訓(xùn)練兩個(gè)神經(jīng)網(wǎng)絡(luò)。
GAN的工作原理
生成器的目標(biāo)是產(chǎn)生和訓(xùn)練數(shù)據(jù)相似的數(shù)據(jù)(以假亂真的圖片),而判別器的目標(biāo)是辨別真假。
生成器的輸入通常為隨機(jī)噪聲,判別器有兩個(gè)輸入,一個(gè)來自訓(xùn)練數(shù)據(jù)中的真圖片,一個(gè)來自生成器生成的假圖片。
GAN的流程如下圖所示
每一次迭代過程中:
GAN訓(xùn)練的目標(biāo)函數(shù)如下所示
- 判別器想要最大化目標(biāo)函數(shù)使得對(duì)于真實(shí)數(shù)據(jù) D(x) 接近 1,對(duì)于假數(shù)據(jù) D(G(z)) 接近 0
- 生成器想要最小化目標(biāo)函數(shù)使得 D(G(z)) 接近 1,也就是欺騙判別器讓它認(rèn)為假數(shù)據(jù)為真
GAN的實(shí)現(xiàn)
這里采用 MNIST 數(shù)據(jù)集作為實(shí)驗(yàn)數(shù)據(jù),最后我們會(huì)看到生成器能夠產(chǎn)生看起來像真的數(shù)字!
導(dǎo)入需要用到的庫
import numpy as np import pandas as pd import matplotlib.pyplot as plt %matplotlib inline import keras from keras.layers import Dense, Dropout, Input from keras.models import Model,Sequential from keras.datasets import mnist from tqdm import tqdm from keras.layers.advanced_activations import LeakyReLU from keras.optimizers import Adam導(dǎo)入數(shù)據(jù)
def load_data():(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train = (x_train.astype(np.float32) - 127.5)/127.5# 將圖片轉(zhuǎn)為向量 x_train from (60000, 28, 28) to (60000, 784) # 每一行 784 個(gè)元素x_train = x_train.reshape(60000, 784)return (x_train, y_train, x_test, y_test) (X_train, y_train,X_test, y_test)=load_data() print(X_train.shape)定義優(yōu)化器
def adam_optimizer():return Adam(lr=0.0002, beta_1=0.5)這里要采用的生成對(duì)抗網(wǎng)絡(luò)的結(jié)構(gòu)如下圖所示
定義生成器:輸入是 100 維,經(jīng)過三層隱藏層,輸出 784 維的向量(造假的圖片)
def create_generator():generator=Sequential()generator.add(Dense(units=256,input_dim=100))generator.add(LeakyReLU(0.2))generator.add(Dense(units=512))generator.add(LeakyReLU(0.2))generator.add(Dense(units=1024))generator.add(LeakyReLU(0.2))generator.add(Dense(units=784, activation='tanh'))generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())return generator g=create_generator() g.summary()定義判別器:判別器的輸入為真實(shí)圖片或者由生成器造出來的假圖片(784維),經(jīng)過三層隱藏層,輸出類別(1 維)
def create_discriminator():discriminator=Sequential()discriminator.add(Dense(units=1024,input_dim=784))discriminator.add(LeakyReLU(0.2))discriminator.add(Dropout(0.3))discriminator.add(Dense(units=512))discriminator.add(LeakyReLU(0.2))discriminator.add(Dropout(0.3))discriminator.add(Dense(units=256))discriminator.add(LeakyReLU(0.2))discriminator.add(Dense(units=1, activation='sigmoid'))discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())return discriminator d =create_discriminator() d.summary()定義生成對(duì)抗網(wǎng)絡(luò)
def create_gan(discriminator, generator):discriminator.trainable=False# 這是一個(gè)鏈?zhǔn)侥P?#xff1a;輸入經(jīng)過生成器、判別器得到輸出gan_input = Input(shape=(100,))x = generator(gan_input)gan_output= discriminator(x)gan= Model(inputs=gan_input, outputs=gan_output)gan.compile(loss='binary_crossentropy', optimizer='adam')return gan gan = create_gan(d,g) gan.summary()定義畫圖函數(shù)來可視化圖片的生成
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):noise= np.random.normal(loc=0, scale=1, size=[examples, 100])generated_images = generator.predict(noise)generated_images = generated_images.reshape(100,28,28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest')plt.axis('off')plt.tight_layout()plt.savefig('gan_generated_image %d.png' %epoch)生成對(duì)抗網(wǎng)絡(luò)的訓(xùn)練函數(shù)
def training(epochs=1, batch_size=128):#導(dǎo)入數(shù)據(jù)(X_train, y_train, X_test, y_test) = load_data()batch_count = X_train.shape[0] / batch_size# 定義生成器、判別器和GAN網(wǎng)絡(luò)generator= create_generator()discriminator= create_discriminator()gan = create_gan(discriminator, generator)for e in range(1,epochs+1 ):print("Epoch %d" %e)for _ in tqdm(range(int(batch_count))):#產(chǎn)生噪聲喂給生成器noise= np.random.normal(0,1, [batch_size, 100])# 產(chǎn)生假圖片generated_images = generator.predict(noise)# 一組隨機(jī)真圖片image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]# 真假圖片拼接 X= np.concatenate([image_batch, generated_images])# 生成數(shù)據(jù)和真實(shí)數(shù)據(jù)的標(biāo)簽y_dis=np.zeros(2*batch_size)y_dis[:batch_size]=0.9# 預(yù)訓(xùn)練,判別器區(qū)分真假discriminator.trainable=Truediscriminator.train_on_batch(X, y_dis)# 欺騙判別器 生成的圖片為真的圖片noise= np.random.normal(0,1, [batch_size, 100])y_gen = np.ones(batch_size)# GAN的訓(xùn)練過程中判別器的權(quán)重需要固定 discriminator.trainable=False# GAN的訓(xùn)練過程為交替“訓(xùn)練判別器”和“固定判別器權(quán)重訓(xùn)練鏈?zhǔn)侥P汀眊an.train_on_batch(noise, y_gen)if e == 1 or e % 50 == 0:# 畫圖 看一下生成器能生成什么plot_generated_images(e, generator) training(400,256)經(jīng)過訓(xùn)練后生成的圖片
一個(gè)epoch后生成器還是個(gè)小學(xué)生
100個(gè)epoch后生成器已經(jīng)有點(diǎn)樣子了
400個(gè)epoch后生成器可以出師了
是不是已經(jīng)學(xué)得像模像樣了,這樣就能夠利用噪聲通過生成器來生成以假亂真的圖片了。
總結(jié)
以上是生活随笔為你收集整理的plt生成固定的colormap_白话生成对抗网络GAN及代码实现的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 痛风的诊断
- 下一篇: 骨龄片怎么看骨龄几岁