简述生成式对抗网络 GAN
【轉載請注明出處】chenrudan.github.io
本文主要闡述了對生成式對抗網絡的理解,首先談到了什么是對抗樣本,以及它與對抗網絡的關系,然后解釋了對抗網絡的每個組成部分,再結合算法流程和代碼實現來解釋具體是如何實現并執行這個算法的,最后給出一個基于對抗網絡改寫的去噪網絡運行的結果,效果雖然挺差的,但是有些地方還是挺有意思的。
- 1. 對抗樣本
- 2. 生成式對抗網絡GAN
- 3. 代碼解釋
- 4. 運行實例
- 5. 小結
- 6. 引用
1. 對抗樣本(adversarial examples)
14年的時候Szegedy在研究神經網絡的性質時,發現針對一個已經訓練好的分類模型,將訓練集中樣本做一些細微的改變會導致模型給出一個錯誤的分類結果,這種雖然發生擾動但是人眼可能識別不出來,并且會導致誤分類的樣本被稱為對抗樣本,他們利用這樣的樣本發明了對抗訓練(adversarial training),模型既訓練正常的樣本也訓練這種自己造的對抗樣本,從而改進模型的泛化能力[1]。如下圖所示,在未加擾動之前,模型認為輸入圖片有57.7%的概率為熊貓,但是加了之后,人眼看著好像沒有發生改變,但是模型卻認為有99.3%的可能是長臂猿。
圖1 對抗樣本的產生(圖來源[2])這個問題乍一看很像過擬合,在Goodfellow在15年[3]提到了其實模型欠擬合也能導致對抗樣本,因為從現象上來說是輸入發生了一定程度的改變就導致了輸出的不正確,例如下圖一,上下分別是過擬合和欠擬合導致的對抗樣本,其中綠色的o和x代表訓練集,紅色的o和x即對抗樣本,明顯可以看到欠擬合的情況下輸入發生改變也會導致分類不正確(其實這里我覺得有點奇怪,因為圖中所描述的對抗樣本不一定就是跟原始樣本是同分布的,感覺是人為造的一個東西,而不是真實數據的反饋)。在[1]中作者覺得這種現象可能是因為神經網絡的非線性和過擬合導致的,但Goodfellow卻給出了更為準確的解釋,即對抗樣本誤分類是因為模型的線性性質導致的,說白了就是因為wTxwTx存在點乘,當xx的每一個維度上都發生改變x?=x+ηx~=x+η,就會累加起來在點乘的結果上附加上一個比較大的和wTx?=wTx+wTηwTx~=wTx+wTη,而這個值可能就改變了預測結果。例如[4]中給出的一個例子,假設現在用邏輯回歸做二分類,輸入向量是x=[2,?1,3,?2,2,2,1,?4,5,1]x=[2,?1,3,?2,2,2,1,?4,5,1],權重向量是w=[?1,?1,1,?1,1,?1,1,1,?1,1]w=[?1,?1,1,?1,1,?1,1,1,?1,1],點乘結果是-3,類預測為1的概率為0.0474,假如將輸入變為xad=x+0.5w=[1.5,?1.5,3.5,?2.5,2.5,1.5,1.5,?3.5,4.5,1.5]xad=x+0.5w=[1.5,?1.5,3.5,?2.5,2.5,1.5,1.5,?3.5,4.5,1.5],那么類預測為1的概率就變成了0.88,就因為輸入在每個維度上的改變,導致了前后的結果不一致。
圖2 過/欠擬合導致對抗樣本(圖來源[3])如果認為對抗樣本是因為模型的線性性質導致的,那么是否能夠構造出一個方法來生成對抗樣本,即如何在輸入上加擾動,Goodfellow給出了一種構造方法fast gradient sign method[2],其中JJ是損失函數,再對輸入xx求導,θθ是模型參數,??是一個非常小的實數。圖1中就是?=0.007?=0.007。
η=?sign(▽xJ(θ,x,y))(1)η=?sign(▽xJ(θ,x,y))(1)這個構造方法在[4]中有比較多的實例,這里截取了兩個例子來說明,用imagenet圖片縮放到64*64來訓練一個一層的感知機,輸入是64*64*3,輸出是1000,權重是64*64*3*1000,訓練好之后取權重矩陣對應某個輸出類別的一行64*64*3,將這行還原成64*64圖片顯示為下圖中第二列,再用公式1的方法從第一列的原始圖片中算出第三列的對抗樣本,可以看到第一行從預測為狐貍變成了預測為金魚,第二行變成了預測為校車。
圖3 構造對抗樣本(圖來源[4])實際上不是只有純線性模型才會出現這種情況,卷積網絡的卷積其實就是線性操作,因此也有預測不穩定的情況,relu/maxout甚至sigmoid的中間部分其實也算是線性操作。因為可以自己構造對抗樣本,那么就能應用這個性質來訓練模型,讓模型泛化能力更強。因而[2]給定了一種新的目標函數也就是下面的式子,相當于對輸入加入一些干擾,并且也通過實驗結果證實了訓練出來的模型更加能夠抵抗對抗樣本的影響。
J?(θ,x,y)=αJ(θ,x,y)+(1?α)J(θ,x+?sign(▽xJ(θ,x,y)))(2)J~(θ,x,y)=αJ(θ,x,y)+(1?α)J(θ,x+?sign(▽xJ(θ,x,y)))(2)對抗樣本跟生成式對抗網絡沒有直接的關系,對抗網絡是想學樣本的內在表達從而能夠生成新的樣本,但是有對抗樣本的存在在一定程度上說明了模型并沒有學習到數據的一些內部表達或者分布,而可能是學習到一些特定的模式足夠完成分類或者回歸的目標而已。公式1的構造方法只是在梯度方向上做了一點非常小的變化,但是模型就無法正確的分類。此外還觀察到一個現象,用不同結構的多個分類器來學習相同數據,往往會將相同的對抗樣本誤分到相同的類中,這個現象看上去是所有的分類器都被相同的變化所干擾了。
2. 生成式對抗網絡GAN
14年Goodfellow提出Generative adversarial nets即生成式對抗網絡[5],它要解決的問題是如何從訓練樣本中學習出新樣本,訓練樣本是圖片就生成新圖片,訓練樣本是文章就輸出新文章等等。如果能夠知道訓練樣本的分布p(x)p(x),那么就可以在分布中隨機采樣得到新樣本,大部分的生成式模型都采用這種思路,GAN則是在學習從隨機變量zz到訓練樣本xx的映射關系,其中隨機變量可以選擇服從正太分布,那么就能得到一個由多層感知機組成的生成網絡G(z;θg)G(z;θg),網絡的輸入是一個一維的隨機變量,輸出是一張圖片。如何讓輸出的偽造圖片看起來像訓練樣本,Goodfellow采用了這樣一種方法,在生成網絡后面接上一個多層感知機組成的判別網絡D(x;θd)D(x;θd),這個網絡的輸入是隨機選擇一張真實樣本或者生成網絡的輸出,輸出是輸入圖片來自于真實樣本pdatapdata或者生成網絡pgpg的概率,當判別網絡能夠很好的分辨出輸入是不是真實樣本時,也能通過梯度的方式說明什么樣的輸入更加像真實樣本,從而通過這個信息來調整生成網絡。從而GG需要盡可能的讓自己的輸出像真實樣本,而DD則盡可能的將不是真實樣本的情況分辨出來。下圖左邊是GAN算法的概率解釋,右邊是模型構成。
圖4 GAN算法框圖(圖來源[6])GAN的優化是一個極小極大博弈問題,最終的目的是generator的輸出給discriminator時很難判斷是真實or偽造的,即極大化DD的判斷能力,極小化將GG的輸出判斷為偽造的概率,公式如下。論文[5]中將下面式子轉化成了Jensen-shannon散度的形式證明了僅當pg=pdatapg=pdata時能得到全局最小值,即生成網絡能完全的還原出真實樣本分布,并且證明了下式能夠收斂。(算法流程論文講的很清楚,這里就不說了,后面結合代碼一起解釋。)
minGmaxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1?D(G(z)))](3)minGmaxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1?D(G(z)))](3)以上是關于最基本GAN的介紹,最開始我看了論文后產生了幾個疑問,1.為什么不能直接學習GG,即直接學習一個zz到一個xx?2.GG具體是如何訓練的?3.在訓練的時候zz跟xx是一一對應關系嗎?在對代碼理解之后大概能夠給出一個解釋。
3. 代碼解釋
這部分主要結合tensorflow實現代碼[7]、算法流程和下面的變化圖[5]解釋一下具體如何使用DCGAN來生成手寫體圖片。
下圖中黑色虛線是真實數據的高斯分布,綠色的線是生成網絡學習到的偽造分布,藍色的線是判別網絡判定為真實圖片的概率,標x的橫線代表服從高斯分布x的采樣空間,標z的橫線代表服從均勻分布z的采樣空間??梢钥闯?span id="ze8trgl8bvbq" class="MathJax" id="MathJax-Element-34-Frame" tabindex="0" style="display:inline; line-height:normal; text-align:left; word-spacing:normal; word-wrap:normal; white-space:nowrap; float:none; direction:ltr; max-width:none; max-height:none; min-width:0px; min-height:0px; border:0px; padding:0px; margin:0px; position:relative">GG就是學習了從z的空間到x的空間的映射關系。
圖5 GAN運行時各個概率分布圖(圖來源[5])a.起始情況
DD是一個卷積神經網絡,變量名是D,其中一層構造方式如下。
| 12345678 | w = tf.get_variable('w', [4, 4, c_dim, num_filter], initializer=tf.truncated_normal_initializer(stddev=stddev))dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME')biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0))bias = tf.nn.bias_add(dconv, biases)dconv1 = tf.maximum(bias, leak*bias)... |
GG是一個逆卷積神經網絡,變量名是G,其中一層構造方式如下。
| 12345678910 | w = tf.get_variable('w', [4, 4, num_filter, num_filter*2], initializer=tf.random_normal_initializer(stddev=stddev))deconv = tf.nn.conv2d_transpose(gconv2, w, output_shape=[batch_size, s2, s2, num_filter], strides=[1, 2, 2, 1])biases = tf.get_variable('biases', [num_filter],initializer=tf.constant_initializer(0.0))bias = tf.nn.bias_add(deconv, biases)deconv1 = tf.nn.relu(bias, name=scope.name)... |
GG的網絡輸入為一個zdimzdim維服從-1~1均勻分布的隨機變量,這里取的是100.
| 12 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32) |
DD的網絡輸入是一個batch的64*64的圖片,既可以是手寫體數據也可以是GG的一個batch的輸出。
這個過程可以參考上圖的a狀態,判別曲線處于不夠穩定的狀態,兩個網絡都還沒訓練好。
b.訓練判別網絡
判別網絡的損失函數由兩部分組成,一部分是真實數據判別為1的損失,一部分是GG的輸出self.G判別為0的損失,需要優化的損失函數定義如下。
| 123456789 | self.G = self.generator(self.z)self.D, self.D_logits = self.discriminator(self.images)self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))self.d_loss = self.d_loss_real + self.d_loss_fake |
然后將一個batch的真實數據batch_images,和隨機變量batch_z當做輸入,執行session更新DD的參數。
| 123456 | # update discriminator on reald_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars)...out1 = sess.run([d_optim], feed_dict={real_images: batch_images, noise_images: batch_z}) |
這一步可以對比圖b,判別曲線漸漸趨于平穩。
c.訓練生成網絡
生成網絡并沒有一個獨立的目標函數,它更新網絡的梯度來源是判別網絡對偽造圖片求的梯度,并且是在設定偽造圖片的label是1的情況下,保持判別網絡不變,那么判別網絡對偽造圖片的梯度就是向著真實圖片變化的方向。
| 12 | self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) |
然后用同樣的隨機變量batch_z當做輸入更新
| 1234 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) .minimize(self.g_loss, var_list=self.g_vars)...out2 = sess.run([g_optim], feed_dict={noise_images:batch_z}) |
這一步可以對比圖c,pgpg的曲線在漸漸的向真實分布靠攏。而網絡訓練完成之后可以看到pgpg的曲線與pdatapdata重疊在了一起,并且此時判別網絡已經難以區分真實與偽造,因此取值就固定在了1212。
因而針對我之前的問題,2已經有了答案,針對1,為什么不能直接學習GG?這是因為無法確定zz與xx的一一對應關系,就像下圖,兩種對應關系,如果要肯定誰是對誰是錯,那么就得加入一些先驗信息,甚至是直接對真實樣本的估計,那么跟其他的方法不就一樣了么。而問題3,在訓練的時候zz跟xx是一一對應關系嗎?我開始考慮這個問題是因為不清楚是不是一個100維的noise變量就對應著一個手寫體變量圖片,但是現在考慮一下就應該明白在訓練的層面上不是一一對應的,甚至兩者在訓練DD的時候都是分開的,只是可能在分布中會存在這樣一種對應關系而已。
圖6 z與x映射圖(圖來源[8])4. 運行實例
這里本來想用GAN來跑一個去噪的網絡,基于[7]的代碼改了一下輸入,從一個100維的noise向量變成了一張輸入圖片,同時將generator網絡的前面部分變成了卷積網絡,再連上原來的逆卷積,就成了一個去噪網絡,這里我沒太多時間來細致的調節網絡層數、參數等,就隨便試了一下,效果也不是特別的好。代碼在[9]中。首先我通過read_stl10.py對stl10數據集加上了均值為0方差為50的高斯噪聲,前后對比如下。
圖7 增加高斯噪聲前后對比然后執行對抗網絡,會得到如下的去噪效果,從左到右分別是加了噪聲的輸入圖片,對應的generator網絡的輸出圖片,已經對應的干凈圖片,效果不是特別好,輪廓倒是能學到一點,但是這個顏色卻沒學到。
圖8 去噪對比5. 小結
剛開始搜資料的時候發現了對抗樣本,以為跟對抗網絡有關系,就看了一下,后來看Goodfellow的論文時發現其實沒什么關系,但是還是寫了一些內容,因為這個東西的存在還是值得了解的,而對抗網絡這個想法真的太贊了,它將一個無監督問題轉化為有監督,更加像一種learn的方式來學習數據應該是如何產生,而不是find的方式來找某些特征,但是訓練也是一個難題,從我的經驗來看,特別容易過擬合,而且確實有一種對抗的感覺在里面,因為generator的輸入時好時壞,總的來說是個很棒的算法,非常期待接下來的研究。
6. 引用
[1]?Intriguing properties of neural networks
[2]?EXPLAINING AND HARNESSING ADVERSARIAL EXAMPLES
[3]?Adversarial Examples
[4]?Breaking Linear Classifiers on ImageNet
[5]?Generative Adversarial Nets
[6]?Quick introduction to GANs
[7]?carpedm20/DCGAN-tensorflow
[8]?Generative Adversarial Nets in TensorFlow (Part I)
[9]?chenrudan/deep-learning/denoise_dcgan/
總結
以上是生活随笔為你收集整理的简述生成式对抗网络 GAN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Part 2 – Deep analys
- 下一篇: 【David Silver强化学习公开课