生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记
生成對抗網絡入門詳解及TensorFlow源碼實現–深度學習筆記
一、生成對抗網絡(GANs)
生成對抗網絡是一種生成模型(Generative Model),其背后最基本的思想就是從訓練庫里獲取很多的訓練樣本(Training Examples),從而學習這些訓練案例生成的概率分布。
GAN[Goodfellow Ian,GAN]啟發自博弈論中的二人零和博弈(two-player game),由[Goodfellow et al, NIPS 2014]開創性地提出。在二人零和博弈中,兩位博弈方的利益之和為零或一個常數,即一方有所得,另一方必有所失。GAN模型中的兩位博弈方分別由生成式模型(generative model)和判別式模型(discriminative model)充當。生成模型G捕捉樣本數據的分布,判別模型是一個二分類器,估計一個樣本來自于訓練數據(而非生成數據)的概率。G和D一般都是非線性映射函數,例如多層感知機、卷積神經網絡等。
二、生成對抗網絡的原理
1、生成對抗過程
GANs的方法,就是讓兩個網絡相互競爭“玩一個游戲”。
其中一個叫做生成器網絡( Generator Network),它不斷捕捉訓練庫里真實圖片的概率分布,將輸入的隨機噪聲(Random Noise)轉變成新的樣本(也就是假數據)。
另一個叫做判別器網絡(Discriminator Network),它可以同時觀察真實和假造的數據,判斷這個數據到底是不是真的。
所以整個訓練過程包含兩步,(在下圖里,判別器用 D 表示,生成器用 G 表示,真實數據庫樣本用 X 表示,噪聲用 Z 表示)。
第一步,只有判別器D參與。
我們把X樣本輸入可微函數D里運行,D輸出0-1之間的某個值,數值越大意味著X樣本是真實的可能性越大。在這個過程中,判別器D盡可能使輸出的值靠近1,因為這一階段的X樣本就是真實的圖片。
第二步,判別器D和生成器G都參與。
我們首先將噪聲數據Z喂給生成器G,G從原有真實圖像庫里學習概率分布,從而產生假的圖像樣本。然后,我們把假的數據交給判別器D。這一次,D將盡可能輸入數值0,這代表著輸入數據Z是假的。
所以這個過程中,判別器D相當于一個監督情況下的二分類器,數據要么歸為1,要么歸為0。
與傳統神經網絡訓練不一樣的且有趣的地方,就是我們訓練生成器的方法不同。生成器一心想要“騙過”判別器。使用博弈理論分析技術,我們可以證明這里面存在一種均衡。
2、數學原理
在訓練過程中,生成網絡G的目標就是盡量生成真實的圖片去欺騙判別網絡D。而D的目標就是盡量把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的“博弈過程”。
最后博弈的結果是什么?在最理想的狀態下,G可以生成足以“以假亂真”的圖片G(z)。對于D來說,它難以判定G生成的圖片究竟是不是真實的,因此D(G(z)) = 0.5。
這樣我們的目的就達成了:我們得到了一個生成式的模型G,它可以用來生成圖片。
以上只是大致說了一下GAN的核心原理,如何用數學語言描述呢?這里直接摘錄論文里的公式:
簡單分析一下這個公式:
? 整個式子由兩項構成。x表示真實圖片,z表示輸入G網絡的噪聲,而G(z)表示G網絡生成的圖片。
? D(x)表示D網絡判斷真實圖片是否真實的概率(因為x就是真實的,所以對于D來說,這個值越接近1越好)。而D(G(z))是D網絡判斷G生成的圖片的是否真實的概率。
? G的目的:上面提到過,D(G(z))是D網絡判斷G生成的圖片是否真實的概率,G應該希望自己生成的圖片“越接近真實越好”。也就是說,G希望D(G(z))盡可能得大,這時V(D, G)會變小。因此我們看到式子的最前面的記號是min_G。
? D的目的:D的能力越強,D(x)應該越大,D(G(x))應該越小。這時V(D,G)會變大。因此式子對于D來說是求最大(max_D)
三、GAN的優勢與缺陷
1、優勢
? 根據實際的結果,它們看上去可以比其它模型產生了更好的樣本(圖像更銳利、清晰)。
? 生成對抗式網絡框架能訓練任何一種生成器網絡(理論上-實踐中,用 REINFORCE 來訓練帶有離散輸出的生成網絡非常困難)。大部分其他的框架需要該生成器網絡有一些特定的函數形式,比如輸出層是高斯的。重要的是所有其他的框架需要生成器網絡遍布非零質量(non-zero mass)。生成對抗式網絡能學習可以僅在與數據接近的細流形(thin manifold)上生成點。
? 不需要設計遵循任何種類的因式分解的模型,任何生成器網絡和任何鑒別器都會有用。
? 無需利用馬爾科夫鏈反復采樣,無需在學習過程中進行推斷(Inference),回避了近似計算棘手的概率的難題。
2、存在的主要問題:
? 解決不收斂(non-convergence)的問題。
目前面臨的基本問題是:所有的理論都認為 GAN 應該在納什均衡(Nash equilibrium)上有卓越的表現,但梯度下降只有在凸函數的情況下才能保證實現納什均衡。當博弈雙方都由神經網絡表示時,在沒有實際達到均衡的情況下,讓它們永遠保持對自己策略的調整是可能的【OpenAI Ian Goodfellow的Quora】。
? 難以訓練:崩潰問題(collapse problem)
GAN模型被定義為極小極大問題,沒有損失函數,在訓練過程中很難區分是否正在取得進展。GAN的學習過程可能發生崩潰問題(collapse problem),生成器開始退化,總是生成同樣的樣本點,無法繼續學習。當生成模型崩潰時,判別模型也會對相似的樣本點指向相似的方向,訓練無法繼續。
? 無需預先建模,模型過于自由不可控。
與其他生成式模型相比,GAN這種競爭的方式不再要求一個假設的數據分布,即不需要formulate p(x),而是使用一種分布直接進行采樣sampling,從而真正達到理論上可以完全逼近真實數據,這也是GAN最大的優勢。然而,這種不需要預先建模的方法缺點是太過自由了,對于較大的圖片,較多的 pixel的情形,基于簡單 GAN 的方式就不太可控了。在GAN[Goodfellow Ian, Pouget-Abadie J] 中,每次學習參數的更新過程,被設為D更新k回,G才更新1回,也是出于類似的考慮。
四、DCGANs:深度卷積生成對抗網絡
DCGANs的基本架構就是使用幾層“反卷積”(Deconvolution)網絡。“反卷積”類似于一種反向卷積,這跟用反向傳播算法訓練監督的卷積神經網絡(CNN)是類似的操作。
CNN是將圖像的尺寸壓縮,變得越來越小,而反卷積是將初始輸入的小數據(噪聲)變得越來越大(但反卷積并不是CNN的逆向操作,這個下面會有詳解)。
如果你要把卷積核移動不止一個位置, 使用的卷積滑動步長更大,那么在反卷積的每一層,你所得到的圖像尺寸就會越大。
這個論文里另一個重要思想,就是在大部分網絡層中使用了“批量規范化”(batch normalization),這讓學習過程的速度更快且更穩定。另一個有趣的思想就是,如何處理生成器里的“池化層”(Pooling Layers),傳統CNN使用的池化層,往往取區域平均或最大來壓縮表征數據的尺寸。
在反卷積過程中,從代碼到最終生成圖片,表征數據變得越來越大,我們需要某個東西來逐漸擴大表征的尺寸。但最大值池化(max-pooling)過程并不可逆,所以DCGANs那篇論文里,并沒有采用池化的逆向操作,而只是讓“反卷積”的滑動步長設定為2或更大值,這一方法確實會讓表征尺寸按我們的需求增大。
DCGANs非常擅長生成特定Domain里的小圖片,這里是一些生成的“臥室”圖片樣本。這些圖片分辨率不是很高,但是你可以看到里面包含了門、窗戶、棉被、枕頭、床頭板、燈具等臥室常見物品。
五、生成對抗網絡應用
1、GANs的應用:“文本轉圖像”(Text to Image)
我們可以用GANs做很多應用,其中一種就是“文本轉圖像”(Text to Image)。在Scott Reed等人的一篇論文里(Generative Adversarial Text to Image Synthesis,鏈接 https://arxiv.org/abs/1605.05396),GANs根據輸入的信息產生了相關圖像,。
也就是說,生成器里輸入的不僅是隨機噪聲,還有一些特定的語句信息。所以判別器不僅要區分樣本是否是真實的,還要判定其是否與輸入的語句信息相符。
這里是他們的實驗結果,左上角的圖里有一些鳥,鳥的胸脯和鳥冠是是粉色,主羽和次羽是黑色,與所給語句描述的信息相符。
但是我們也看到,仍然存在“模型崩潰”問題,在右下角的黃白花里,確實產生了白色花瓣和黃色花蕊的花朵,但它們多少看起來是在同一個方向上映射出來的同一朵花,它們的花瓣數和尺寸幾乎相同。
所以,模型在輸出的多樣性方面還有些問題,這需要解決。但可喜的地方在于,輸入的語句信息都比較好的映射到產生的圖像樣本中。
2、有趣的GANs 圖像生成應用
在Indico和Facebook發布了他們自己的DCGAN代碼之后,很多人開發出他們自己的、有趣的GANs應用。有的生成新的花朵圖像,還有新動漫角色。我個人最喜歡的,是一個能生成新品種精靈寶可夢的應用。
在一個 Youtube 視頻,你會看到學習過程:生成器被迫去學習怎么騙過判別器,圖像逐漸變得更真實。有些生成的寶可夢,雖然它們是全新的品種,看上去就像真的一樣。這些圖像的真實感并沒有一些專業學術論文里面的那么強,但對于現在的生成模型來說,不經過任何額外處理就能得到這樣的結果,已經非常不錯了。
3、超分辨率
一篇最近發表的論文,描述怎么利用GANs進行超分辨率重建(Super-Resolution)。我不確定這能否在本視頻中體現出來,因為視頻清晰度的限制。基本思想是,你可以在有條件的GANs里,輸入低分辨率圖像,然后輸出高分版本。使用生成模型的原因在于,這是一個約束不足(underconstrained)的問題:對于任何一個低分辨率圖像,有無數種可能的高分辨率版本。相比其他生成模型,GANs特別適用超分辨率應用。因為GANs的專長就是創建極有真實感的樣本。它們并不特別擅長做概率函數密度的估測,但在超分辨率應用中,我們最終關心的是輸出高分圖像,而不是概率分布。
(從左到右分別為:圖1、2、3、4)
上面展示的四幅圖像中,最左邊的是原始高分圖像(圖1),剩下的其余三張圖片都是通過對圖片的降采樣(Down Sample)生成的。我們把降采樣得到的圖片用不同的方法進行放大,以期得到跟原始圖像同樣的品質。
這些方法有很多種,比如我們用雙三次插值(Bicubic Interpolation)方式,生成的圖像(圖2)看起來很模糊,且對比度很低。另一個深度學習方法SRResNet(圖3)的效果更好,圖片已經干凈了很多。但若采用GANs重建的圖片(圖4),有著比其它兩種方式更低的信噪比。雖然我們直觀上覺得圖3看起來更清晰,事實上它的信噪比更高一些。GANs在量化矩陣(Quantitative Matrix)和人眼清晰度感知兩方面,都有很好的表現。
六、TensorFlow源碼(生成手寫字體)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np from skimage.io import imsave import os import shutilimg_height = 28 img_width = 28 img_size = img_height * img_widthto_train = True to_restore = False output_path = "output"# 總迭代次數500 max_epoch = 500h1_size = 150 h2_size = 300 z_size = 100 batch_size = 256# generate (model 1) def build_generator(z_prior):w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)h3 = tf.matmul(h2, w3) + b3x_generate = tf.nn.tanh(h3)g_params = [w1, b1, w2, b2, w3, b3]return x_generate, g_params# discriminator (model 2) def build_discriminator(x_data, x_generated, keep_prob):# tf.concatx_in = tf.concat([x_data, x_generated],0)w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)h3 = tf.matmul(h2, w3) + b3y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))d_params = [w1, b1, w2, b2, w3, b3]return y_data, y_generated, d_params# def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5img_h, img_w = batch_res.shape[1], batch_res.shape[2]grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)for i, res in enumerate(batch_res):if i >= grid_size[0] * grid_size[1]:breakimg = (res) * 255img = img.astype(np.uint8)row = (i // grid_size[0]) * (img_h + grid_pad)col = (i % grid_size[1]) * (img_w + grid_pad)img_grid[row:row + img_h, col:col + img_w] = imgimsave(fname, img_grid)def train():# load data(mnist手寫數據集)mnist = input_data.read_data_sets('MNIST_data', one_hot=True)x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")keep_prob = tf.placeholder(tf.float32, name="keep_prob")global_step = tf.Variable(0, name="global_step", trainable=False)# 創建生成模型x_generated, g_params = build_generator(z_prior)# 創建判別模型y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)# 損失函數的設置d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))g_loss = - tf.log(y_generated)optimizer = tf.train.AdamOptimizer(0.0001)# 兩個模型的優化函數d_trainer = optimizer.minimize(d_loss, var_list=d_params)g_trainer = optimizer.minimize(g_loss, var_list=g_params)init = tf.initialize_all_variables()saver = tf.train.Saver()# 啟動默認圖sess = tf.Session()# 初始化sess.run(init)if to_restore:chkpt_fname = tf.train.latest_checkpoint(output_path)saver.restore(sess, chkpt_fname)else:if os.path.exists(output_path):shutil.rmtree(output_path)os.mkdir(output_path)z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)steps = 60000 / batch_sizefor i in range(sess.run(global_step), max_epoch):for j in np.arange(steps): # for j in range(steps):print("epoch:%s, iter:%s" % (i, j))# 每一步迭代,我們都會加載256個訓練樣本,然后執行一次train_stepx_value, _ = mnist.train.next_batch(batch_size)x_value = 2 * x_value.astype(np.float32) - 1z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)# 執行生成sess.run(d_trainer,feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})# 執行判別if j % 1 == 0:sess.run(g_trainer,feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})show_result(x_gen_val, "output/sample{0}.jpg".format(i))z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))sess.run(tf.assign(global_step, i + 1))saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)def test():z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")x_generated, _ = build_generator(z_prior)chkpt_fname = tf.train.latest_checkpoint(output_path)init = tf.initialize_all_variables()sess = tf.Session()saver = tf.train.Saver()sess.run(init)saver.restore(sess, chkpt_fname)z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})show_result(x_gen_val, "output/test_result.jpg")if __name__ == '__main__':if to_train:train()else:test()參考文獻
http://blog.csdn.net/solomon1558/article/details/52549409
http://www.leiphone.com/news/201612/eAOGpvFl60EgFSwS.html
http://www.itwendao.com/article/detail/403491.html
總結
以上是生活随笔為你收集整理的生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tcpdump命令使用总结
- 下一篇: Ubuntu服务器上搭建solo个人博客