Gumbel-Softmax Trick和Gumbel分布 附VAE讲解
轉自https://www.cnblogs.com/initial-h/p/9468974.html?寫的非常好,思路清晰,順帶連VAE trick也講了
之前看MADDPG論文的時候,作者提到在離散的信息交流環境中,使用了Gumbel-Softmax estimator。于是去搜了一下,發現該技巧應用甚廣,如深度學習中的各種GAN、強化學習中的A2C和MADDPG算法等等。只要涉及在離散分布上運用重參數技巧時(re-parameterization),都可以試試Gumbel-Softmax Trick。
??這篇文章是學習以下鏈接之后的個人理解,內容也基本出于此,需要深入理解的可以自取。
- The Humble Gumbel Distribution
- The Gumbel-Max Trick for Discrete Distributions
- The Gumbel-Softmax Trick for Inference of Discrete Variables
- 如何理解Gumbel-Max trick?
??這篇文章從直觀感覺講起,先講Gumbel-Softmax Trick用在哪里及如何運用,再編程感受Gumbel分布的效果,最后討論數學證明。
目錄
- 一、Gumbel-Softmax Trick用在哪里
- 問題來源
- Re-parameterization Trick
- Gumbel-Softmax Trick
- 二、Gumbel分布采樣效果
- 三、數學證明
一、Gumbel-Softmax Trick用在哪里
問題來源
??通常在強化學習中,如果動作空間是離散的,比如上、下、左、右四個動作,通常的做法是網絡輸出一個四維的one-hot向量(不考慮空動作),分別代表四個動作。比如[1,0,0,0]代表上,[0,1,0,0]代表下等等。而具體取哪個動作呢,就根據輸出的每個維度的大小,選擇值最大的作為輸出動作,即argmax(v)
。
??例如網絡輸出的四維向量為v=[?20,10,9.6,6.2]
,第二個維度取到最大值10,那么輸出的動作就是[0,1,0,0],也就是下,這和多類別的分類任務是一個道理。但是這種取法有個問題是不能計算梯度,也就不能更新網絡。通常的做法是加softmax函數,把向量歸一化,這樣既能計算梯度,同時值的大小還能表示概率的含義。softmax函數定義如下:
?
σ(zi)=ezi∑j=1Kezj
?
??那么將v=[?20,10,9.6,6.2]
通過softmax函數后有σ(v)=[0,0.591,0.396,0.013],這樣做不會改變動作或者說類別的選取,同時softmax傾向于讓最大值的概率顯著大于其他值,比如這里10和9.6經過softmax放縮之后變成了0.591和0.396,6.2對應的概率更是變成了0.013,這有利于把網絡訓成一個one-hot輸出的形式,這種方式在分類問題中是常用方法。
??但是這么做還有一個問題,這個表示概率的向量σ(v)=[0,0.591,0.396,0.013]并沒有真正顯示出概率的含義,因為一旦某個值最大,就選擇相應的動作或者分類。比如σ(v)=[0,0.591,0.396,0.013]和σ(v)=[0,0.9,0.1,0]
在類別選取的結果看來沒有任何差別,都是選擇第二個類別,但是從概率意義上講差別是巨大的。所以需要一種方法不僅選出動作,而且遵從概率的含義。
??很直接的方法是依概率采樣就完事了,比如直接用np.random.choice函數依照概率生成樣本值,這樣概率就有意義了。這樣做確實可以,但是又有一個問題冒了出來:這種方式怎么計算梯度?不能計算梯度怎么用BP的方式更新網絡?
??這時重參數(re-parameterization)技巧解決了這個問題,這里有詳盡的解釋,不過比較晦澀。簡單來說重參數技巧的一個用處是把采樣的步驟移出計算圖,這樣整個圖就可以計算梯度BP更新了。之前我一直在想分類任務直接softmax之后BP更新不就完事了嗎,為什么非得采樣。后來看了VAE和GAN之后明白,還有很多需要采樣訓練的任務。這里舉簡單的VAE(變分自編碼器)的例子說明需要采樣訓練的任務以及重參數技巧,詳細內容來自視頻和博客。
Re-parameterization Trick
??最原始的自編碼器通常長這樣:
??左右兩邊是端到端的出入輸出網絡,中間的綠色是提取的特征向量,這是一種直接從圖片提取特征的方式。
??而VAE長這樣:
??VAE的想法是不直接用網絡去提取特征向量,而是提取這張圖像的分布特征,也就把綠色的特征向量替換為分布的參數向量,比如說均值和標準差。然后需要decode圖像的時候,就從encode出來的分布中采樣得到特征向量樣本,用這個樣本去重建圖像,這時怎么計算梯度的問題就出現了。
??重參數技巧可以解決這個問題,它長下面這樣:
??假設圖中的x
和?表示VAE中的均值和標準差向量,它們是確定性的節點。而需要輸出的樣本z是帶有隨機性的節點,重參數就是把帶有隨機性的z變成確定性的節點,同時隨機性用另一個輸入節點?代替。例如,這里用正態分布采樣,原本從均值為x和標準差為?的正態分布N(x,?2)中采樣得到z。將其轉化成從標準正態分布N(0,1)中采樣得到?,再計算得到z=x+???。這樣一來,采樣的過程移出了計算圖,整張計算圖就可以計算梯度進行更新了,而新加的?
的輸入分支不做更新,只當成一個沒有權重變化的輸入。
??到這里,需要采樣訓練的任務實例以及重參數技巧基本有個概念了。
Gumbel-Softmax Trick
??VAE的例子是一個連續分布(正態分布)的重參數,離散分布的情況也一樣,首先需要可以采樣,使得離散的概率分布有意義而不是只取概率最大的值,其次需要可以計算梯度。那么怎么做到的,具體操作如下:
??對于n
維概率向量π,對π對應的離散隨機變量xπ
添加Gumbel噪聲,再取樣
?
xπ=argmax(log(πi)+Gi)
?
??其中,Gi
是獨立同分布的標準Gumbel分布的隨機變量,標準Gumbel分布的CDF為F(x)=e?e?x。
??這就是Gumbel-Max trick。可以看到由于這中間有一個argmax操作,這是不可導的,所以用softmax函數代替之,也就是Gumbel-Softmax Trick,而Gi可以通過Gumbel分布求逆從均勻分布生成,即Gi=?log(?log(Ui)),Ui~U(0,1)
,這樣就搞定了。
??具體實踐是這樣操作的,
- 對于網絡輸出的一個n
維向量v,生成n個服從均勻分布U(0,1)的獨立樣本?1,...,?n
- ?
- 通過Gi=?log(?log(?i))
- 計算得到Gi
- ?
- 對應相加得到新的值向量v′=[v1+G1,v2+G2,...,vn+Gn]
- ?
- 通過softmax函數
?
στ(v′i)=ev′i/τ∑j=1nev′j/τ
?
??計算概率大小得到最終的類別。其中τ
是溫度參數。
??直觀上感覺,對于強化學習來說,在選擇動作之前加一個擾動,相當于增加探索度,感覺上是合理的。對于深度學習的任務來說,添加隨機性去模擬分布的樣本生成,也是合情合理的。
二、Gumbel分布采樣效果
??為什么使用Gumbel分布生成隨機數,就能模擬離散概率分布的樣本呢?這部分使用代碼模擬來感受它的優越性。這部分例子和代碼來自這里。
??首先Gumbel分布的概率密度函數長這樣:
?
p(x)=1βe?z?e?z
?
??其中z=x?μβ
。
??Gumbel分布是一類極值分布,那么它表示什么含義呢?原鏈接舉了一個ice cream的例子,沒有get到點。這里舉一個類似的喝水的例子。
??比如你每天都會喝很多次水(比如100次),每次喝水的量也不一樣。假設每次喝水的量服從正態分布N(μ,σ2)
(其實也有點不合理,畢竟喝水的多少不能取為負值,不過無傷大雅能理解就好,假設均值為5),那么每天100次喝水里總會有一個最大值,這個最大值服從的分布就是Gumbel分布。實際上,只要是指數族分布,它的極值分布都服從Gumbel分布。那么上面這個例子的分布長什么樣子呢,作圖有
from scipy.optimize import curve_fit import numpy as np import matplotlib.pyplot as plt mean_hunger = 5 samples_per_day = 100 n_days = 10000 samples = np.random.normal(loc=mean_hunger, size=(n_days, samples_per_day)) daily_maxes = np.max(samples, axis=1)def gumbel_pdf(prob,loc,scale):z = (prob-loc)/scalereturn np.exp(-z-np.exp(-z))/scaledef plot_maxes(daily_maxes):probs,hungers,_=plt.hist(daily_maxes,density=True,bins=100)plt.xlabel('Volume')plt.ylabel('Probability of Volume being daily maximum')(loc,scale),_=curve_fit(gumbel_pdf,hungers[:-1],probs)#curve_fit用于曲線擬合#接受需要擬合的函數(函數的第一個參數是輸入,后面的是要擬合的函數的參數)、輸入數據、輸出數據#返回的是函數需要擬合的參數# https://blog.csdn.net/guduruyu/article/details/70313176plt.plot(hungers,gumbel_pdf(hungers,loc,scale))plt.figure() plot_maxes(daily_maxes)
??那么gumbel分布在離散分布的采樣中效果如何呢?可以作圖比較一下。先定義一個多項分布,作出真實的概率密度圖。再通過采樣的方式比較各種方法的效果。
??如下代碼定義了一個7類別的多項分布,其真實的密度函數如下圖
??首先我們直接根據真實的分布利用np.random.choice函數采樣對比效果
Original probabilities:??0.11?0.05?0.12?0.21?0.12?0.26?0.14
Estimated probabilities:?0.12?0.04?0.12?0.23?0.10?0.26?0.13
??效果意料之中的好。可以想到要是沒有不能求梯度這個問題,直接從原分布采樣是再好不過的。
??接著通過前述的方法添加Gumbel噪聲采樣,同時也添加正態分布和均勻分布的噪聲作對比
Original probabilities:??????0.11?0.05?0.12?0.21?0.12?0.26?0.14
Gumbel Estimated probabilities:?0.11?0.04?0.11?0.23?0.12?0.26?0.14
Normal Estimated probabilities:??0.08?0.02?0.11?0.26?0.11?0.29?0.12
Uniform Estimated probabilities:?0.00?0.00?0.00?0.32?0.01?0.63?0.03
??可以明顯看到Gumbel噪聲的采樣效果是最好的,正態分布其次,均勻分布最差。也就是說可以用Gumbel分布做Re-parameterization使得整個圖計算可導,同時樣本點最接近真實分布的樣本。
三、數學證明
??為什么添加Gumbel噪聲有如此效果,下面闡述問題并給出證明。
??假設有一個K
維的輸出向量,每個維度的值記為xk
,通過softmax函數可得,取到每個維度的概率為:
?
πk=exk∑Kk′=1exk′
?
??這是直接softmax得到的概率密度函數,如果換一種方式,對每個xk
添加獨立的標準Gumbel分布(尺度參數為1,位置參數為0)噪聲,并選擇值最大的維度作為輸出,得到的概率密度同樣為πk。
??下面給出Gumbel分布的概率密度函數和分布函數,并證明這件事情。
??尺度參數為1,位置參數為μ
的Gumbel分布的PDF為
?
f(z;μ)=e?(z?μ)?e?(z?μ)
?
??CDF為
?
F(z;μ)=e?e?(z?μ)
?
??假設第k
個Gumbel分布對應xk,加和得到隨機變量zk=xk+Gk,即相當于zk服從尺度參數為1,位置參數為μ=xk的Gumbel分布。要證明這樣取得的隨機變量zk與原隨機變量相同,只需證明取到zk的概率為πk。也就是zk比其他所有zk′(k′≠k)大的概率為πk
,即
?
P(zk≥zk′;?k′≠k|{xk′}Kk′=1)=πk
?
??關于zk
的條件累積概率分布函數為
?
P(zk≥zk′;?k′≠k|zk,{xk′}Kk′=1)=P(z1≤zk)P(z2≤zk)???P(zk?1≤zk)P(zk+1≤zk)???P(zK≤zk)
?
??即
?
P(zk≥zk′;?k′≠k|zk,{xk′}Kk′=1)=∏k′≠ke?e?(zk?xk′)
?
??對zk
求積分可得邊緣累積概率分布函數
?
P(zk≥zk′;?k′≠k|{xk′}Kk′=1)=∫P(zk≥zk′;?k′≠k|zk,{xk′}Kk′=1)?f(zk;xk)dzk
?
??帶入式子有
?
P(zk≥zk′;?k′≠k|{xk′}Kk′=1)=∫∏k′≠ke?e?(zk?xk′)?e?(zk?xk)?e?(zk?xk)dzk
?
??化簡有
?
P(zk≥zk′;?k′≠k|{xk′}Kk′=1)=∫∏k′≠ke?e?(zk?xk′)?e?(zk?xk)?e?(zk?xk)dzk=∫e?∑k′≠ke?(zk?xk′)?(zk?xk)?e?(zk?xk)dzk=∫e?∑Kk′=1e?(zk?xk′)?(zk?xk)dzk=∫e?(∑Kk′=1exk′)e?zk?zk+xkdzk=∫e?e?zk+ln(∑Kk′=1exk′)?zk+xkdzk=∫e?e?(zk?ln(∑Kk′=1exk′))?(zk?ln(∑Kk′=1exk′))?ln(∑Kk′=1exk′)+xkdzk=e?ln(∑Kk′=1exk′)+xk∫e?e?(zk?ln(∑Kk′=1exk′))?(zk?ln(∑Kk′=1exk′))dzk=exk∑Kk′=1exk′∫e?e?(zk?ln(∑Kk′=1exk′))?(zk?ln(∑Kk′=1exk′))dzk=exk∑Kk′=1exk′∫e?(zk?ln(∑Kk′=1exk′))?e?(zk?ln(∑Kk′=1exk′))dzk
?
??積分里面是μ=ln(∑Kk′=1exk′)
的Gumbel分布,所以整個積分為1。則有
?
P(zk≥zk′;?k′≠k|{xk′}Kk′=1)=exk∑Kk′=1exk′
?
??這和softmax的結果一致。
總結
以上是生活随笔為你收集整理的Gumbel-Softmax Trick和Gumbel分布 附VAE讲解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ECCV2020论文-稀疏性表示-Neu
- 下一篇: 【tensorflow】重置/清除计算图