GAN(对抗生成网络)原理及数学推导
本文主要涉及GAN網絡的直觀理解和其背后的數學原理。
參考課程:
計算機視覺與深度學習 北京郵電大學 魯鵬
概述
在所有生成模型中,GAN屬于 “密度函數未知,直接硬train” 的那一類,和密度函數可定義的PixelRNN/CNN以及變分自編碼器VAE有本質區別。
假設現在我們想做人臉的生成任務。我們希望能找到人臉圖像的真實分布,這樣直接在這個分布上隨便取點,得到的都是人臉的圖像。但是分布非常復雜,且無法知道。
所以,我們考慮用一個簡單的分布和一個映射,將這個簡單的分布映射到真實的分布。我們使用神經網絡來學習這個映射的過程。
GAN的直觀理解
目標函數
GAN網絡的設計思路類似玩家博弈的過程,其主要優化的目標為:
符號說明:P_data是真實數據的分布,P(z)是噪聲分布(可以是均勻分布、高斯分布等), theta(g)是生成器的參數, theta(d)是判別器的參數。
公式中
x^=Gθg(z)\hat{x} = G_\theta{_g}(z) x^=Gθ?g?(z)
表示生成器生成的樣本,而
Dθd(x)D_\theta{_d}(x) Dθ?d?(x)
輸出一個0-1之間的數,表示判別器對輸入的判斷,1表示是真實數據,0表示是生成的數據。
我們先看內側max,調整theta_d(判別器的參數),使得后面式子最大。對于真實樣本(Ex~data那一項),希望判別器生成1;對假樣本x_hat,希望D_theta(d)把他輸出成0,這樣1減去之后最大。
【注意!!在討論max的時候調整d,此時生成器g的參數是固定的!!反之亦然。】
再來看min的時候,學的是生成器g的參數。此時,前面那一項無所謂(與g無關)。此時希望
Dθd(Gθg(z))D_\theta{_d}(G_\theta{_g(z)}) Dθ?d?(Gθ?g?(z))
趨近于1,此時theta_d不變,我們 希望生成的樣本被判別器判斷成1. 也就是固定D的前提下,讓G盡量欺騙D。
theta_d想讓表達式越大越好,所以是梯度上升。
因為判別器最后輸出是(0, 1)的值,所以最后一層是一個sigmoid. 想讓正樣本越大越好,負樣本越小越好,可以用一個二分類交叉熵損失(BCE)監督。【這里體會到:“似然越大越好” 等價于 “交叉熵損失越小越好”,因為那個max里面是一個概率/似然。下文會詳細說明。】
但是實際這樣梯度會出現問題——
這樣訓練的效果很差。因為剛開始生成的爛,梯度還小,學不動;后來生成的好,不太需要變化了,梯度反而很大。
G+D是一個網絡,D在G后面。優化的時候,是凍結一個,訓練另一個。而梯度回傳會首先經過D,再回傳到G。
實際實現時,我們會將min換成max,使得梯度問題得以解決。
用下面這張圖總結下GAN網絡的學習過程。
【這里(a)表示的意思是:一開始,判別器沒有學好,無法區分真實和生成的分布。】
數學推導
JS散度
在開始之前,先給出JS散度的定義。
JS散度度量了兩個概率分布的相似度,是基于KL散度的變體,解決了KL散度非對稱的問題。一般地,JS散度是對稱的,其取值是0到1之間。定義如下:
JS散度是可以理解為“距離”的,因為是對稱的,而KL散度不行,只能說是一種“相似程度”。
極大似然估計 VS KL散度
一般的,我們要選取一個theta,使得似然值最大。
先放結論:
最大化似然 = 最小化KL散度。
【這是一個貫穿機器學習過程的關鍵理解】。
以下是每一步化簡的過程:
回到GAN
Z是噪聲服從的分布,這里可以取均勻分布或高斯分布。我們使用神經網絡建模,學習了一個G,將Z映射到了一個密度分布P_G.
我們希望調整生成器的參數,使得G的密度分布與真實數據的密度分布接近(其中的Div表示散度,不一定是KL散度)。
但是,P_G是神經網絡擬合的,Pdata是未知的,表達式我們根本寫不出來,怎么優化?
以下是解決方法。
1、雖然我們不知道這兩個分布的具體表達式,但是我們可以從中獲取樣本!
2、接著,我們把GAN的目標式子中的z統一換成G(因為樣本是從G的分布里取出來的嘛)。
V(G,D)=Ex?Pdata[logD(x)]+Ex?PG[log(1?D(x))]V(G, D) = E_{x - P_{data}}[logD(x)] + E_{x - P_{G}}[log(1-D(x))] V(G,D)=Ex?Pdata??[logD(x)]+Ex?PG??[log(1?D(x))]
3、與上面類似,我們先考慮優化判別器(對應max的部分)。
這里先給出結論:
最大化maxV(D, G)等價于度量P_data和P_G之間的JS散度!
我們不是沒法度量Div(Pg,Pdata)嘛?現在找到度量方式了!
只需要最大化V(D, G),便可以度量Pg和Pdata之間的JS散度。
先忽略結論的證明,我們繞開了Pdata和Pg數學表達式無法獲得的問題,解決了度量兩個密度分布的方法。因為maxV的時候,只需要把訓練樣本輸入到神經網絡中即可訓練theta_G!
換言之,訓練神經網絡,實際就是在度量Pdata和Pg之間的JS散度。
直觀理解
關于結論的證明,先從直觀的角度來進行。
如果生成的和樣本很像,判別器判別很困難,V(G,D)小【因為判斷困難,真實數據得不到1,生成的假數據也得不到0,V值自然不高】;反之V(G,D)大 ==》 這不就類似在刻畫“散度”嘛?
越好分,值越小,證明他們的距離越小;越難分,值越大,證明他們距離越大!
理論推導
這里用了一個結論:如果想要最大化積分,那么如果對于每個x,f(x)都是最大的,那積分出來的結果也最大,這樣我們就去掉了積分符號。
在x給定的情況下,我們要找到最大的D’,對D求導即可。
將 前面求出的D’帶入V(G, D), 并人為加入1/2的因子,朝著JS散度的方向化簡。
最后我們便會發現,把最優參數帶入后,此時的V(G, D)取到max值,也就是在度量Pdata與PG的JS散度。所以,判別器的輸出值就代表了Pdata和Pg的差異!判別器輸出值越大,表示Pdata和Pg分的越開;輸出值越小,表示他們離得越近。
手動推導及每一步化簡的過程:
再看目標式
我們已經證明了,最大化V(D, G)就等價于計算了JS散度。所以對于上面的3個G,在固定G的情況下,我們可以得到D’為圖中紅色豎線的值(這時V最大)。
而生成器的優化目標為:找到一個最優參數G,使得生成的P_G的概率分布和真實數據的概率分布之間的差異越小越好。
假設我們現在G的候選參數就這三個,那就是從三個值里選擇最小的值,G3就是最后學到的結果(因為他的V最小,而V是JS散度的刻畫,生成器希望差異小)。
“判別器,最大化V(G, D)”可以理解為在藍色的線上找最大值;
“生成器,最小化Div”可以理解為從所有紅線中找出最小值。
而關鍵的橋梁“距離”,就是通過maxV(G, D)實現的。
實際操作
實際做的時候,可以用BCE做損失函數監督。【再次體現最大化似然等價最小化交叉熵】
Summary
但是其實GAN還是有很多問題的,這也是為什么后來出現了WGAN等,這個在這里就按下不表了。
一個小問題
在訓練的過程中,我們往往對于判別器訓練多次,而生成器只訓練一次。這是為什么呢?
一個直觀的理解是“判別器如果訓練不好,那生成器訓練多次也沒什么用”,但這么理解只是流于表面。
可以從上面數學推導的理論來考慮。
優化D是為了是discriminator對應的目標函數最大,也就是在整個數據分布上,盡力做到正確區分,這個需要多輪過程做到,且優化D不會改變Pdata和Pg;但是對于generator,一次優化后,很可能此時此數據分布上的discriminator所最優的區分能力并不適合你已經改變之后的generator,導致不符合理論上的推導(也就是我們要最小化的JSD)【只有在固定D局部優化G時,才能看成近似優化兩個分布的JS散度。 優化G之后,Pg已經變了,此時如果D不動而還訓練G,就不符合理論了】。
總結
以上是生活随笔為你收集整理的GAN(对抗生成网络)原理及数学推导的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: matlab2017中工具箱,【2017
- 下一篇: 高恪一键管控之封杀随身wifi与电视盒子