变分自编码器:球面上的VAE(vMF-VAE)
?PaperWeekly 原創(chuàng) ·?作者|蘇劍林
單位|追一科技
研究方向|NLP、神經(jīng)網(wǎng)絡(luò)
在變分自編碼器:VAE + BN = 更好的 VAE 中,我們講到了 NLP 中訓(xùn)練 VAE 時(shí)常見(jiàn)的 KL 散度消失現(xiàn)象,并且提到了通過(guò) BN 來(lái)使得 KL 散度項(xiàng)有一個(gè)正的下界,從而保證 KL 散度項(xiàng)不會(huì)消失。事實(shí)上,早在 2018 年的時(shí)候,就有類(lèi)似思想的工作就被提出了,它們是通過(guò)在 VAE 中改用新的先驗(yàn)分布和后驗(yàn)分布,來(lái)使得 KL 散度項(xiàng)有一個(gè)正的下界。
該思路出現(xiàn)在 2018 年的兩篇相近的論文中,分別是《Hyperspherical Variational Auto-Encoders》[1] 和《Spherical Latent Spaces for Stable Variational Autoencoders》[2],它們都是用定義在超球面的 von Mises–Fisher(vMF)分布來(lái)構(gòu)建先后驗(yàn)分布。某種程度上來(lái)說(shuō),該分布比我們常用的高斯分布還更簡(jiǎn)單和有趣。
KL散度消失
我們知道,VAE 的訓(xùn)練目標(biāo)是:
其中第一項(xiàng)是重構(gòu)項(xiàng),第二項(xiàng)是 KL 散度項(xiàng),在變分自編碼器:原來(lái)是這么一回事中我們就說(shuō)過(guò),這兩項(xiàng)某種意義上是“對(duì)抗”的,KL 散度項(xiàng)的存在,會(huì)加大解碼器利用編碼信息的難度,如果 KL 散度項(xiàng)為 0,那么說(shuō)明解碼器完全沒(méi)有利用到編碼器的信息。
在 NLP 中,輸入和重構(gòu)的對(duì)象是句子,為了保證效果,解碼器一般用自回歸模型。然而,自回歸模型是非常強(qiáng)大的模型,強(qiáng)大到哪怕沒(méi)有輸入,也能完成訓(xùn)練(退化為無(wú)條件語(yǔ)言模型),而剛才我們說(shuō)了,KL 散度項(xiàng)會(huì)加大解碼器利用編碼信息的難度,所以解碼器干脆棄之不用,這就出現(xiàn)了 KL 散度消失現(xiàn)象。
早期比較常見(jiàn)的應(yīng)對(duì)方案是逐漸增加 KL 項(xiàng)的權(quán)重,以引導(dǎo)解碼器去利用編碼信息。現(xiàn)在比較流行的方案就是通過(guò)某些改動(dòng),直接讓 KL 散度項(xiàng)有一個(gè)正的下界。將先后驗(yàn)分布換為 vMF 分布,就是這種方案的經(jīng)典例子之一。
vMF分布
vMF 分布是定義在 d-1 維超球面的分布,其樣本空間為 ,概率密度函數(shù)則為:
其中 是預(yù)先給定的參數(shù)向量。不難想象,這是 上一個(gè)以 為中心的分布,歸一化因子寫(xiě)成 的形式,意味著它只依賴(lài)于 的模長(zhǎng),這是由于各向同性導(dǎo)致的。由于這個(gè)特性,vMF 分布更常見(jiàn)的記法是設(shè) ,從而:
這時(shí)候 就是 的夾角余弦,所以說(shuō),vMF 分布實(shí)際上就是以預(yù)先為度量的一種分布。由于我們經(jīng)常用余弦值來(lái)度量?jī)蓚€(gè)向量的相似度,因此基于 vMF 分布做出來(lái)的模型,通常更能滿(mǎn)足我們的這個(gè)需求。當(dāng) 的時(shí)候,vMF 分布是球面上的均勻分布。
從歸一化因子 的積分形式來(lái)看,它實(shí)際上也是 vMF 的母函數(shù),從而 vMF 的各階矩也可以通過(guò) 來(lái)表達(dá),比如一階矩為:
可以看到 在方向上跟 一致。 的精確形式可以算出來(lái),但比較復(fù)雜,而且很多時(shí)候我們也不需要精確知道這個(gè)歸一化因子,所以這里我們就不算了。
至于參數(shù) \kappa 的含義,或許設(shè) 我們更好理解,此時(shí) ,熟悉能量模型的同學(xué)都知道,這里的 就是溫度參數(shù),如果 越小( 越大),那么分布就越集中在 附近,反之則越分散(越接近球面上的均勻分布)。因此, 也被形象地稱(chēng)為“凝聚度(concentration)”參數(shù)。
從vMF采樣
對(duì)于 vMF 分布來(lái)說(shuō),需要解決的第一個(gè)難題是如何實(shí)現(xiàn)從它里邊采樣出具體的樣本來(lái)。尤其是如果我們要將它應(yīng)用到 VAE 中,那么這一步是至關(guān)重要的。
3.1 均勻分布
最簡(jiǎn)單是 的情形,也就是 d-1 維球面上的均勻分布,因?yàn)闃?biāo)準(zhǔn)正態(tài)分布本來(lái)就是各向同性的,其概率密度正比于 只依賴(lài)于模長(zhǎng),所以我們只需要從 d 為標(biāo)準(zhǔn)正態(tài)分布中采樣一個(gè) z,然后讓 就得到了球面上的均勻采樣結(jié)果。
3.2 特殊方向
接著,對(duì)于 的情形,我們記 ,首先考慮一種特殊的情況:。事實(shí)上,由于各向同性的原因,很多時(shí)候我們都只需要考慮這個(gè)特殊情況,然后就可以平行地推廣到一般情形。
此時(shí)概率密度正比于 ,然后我們轉(zhuǎn)換到球坐標(biāo)系:
那么:
這個(gè)分解表明,從該 vMF 分布中采樣,等價(jià)于先從概率密度正比于 的分布采樣一個(gè) ,然后從 d-2 維超球面上均勻采樣一個(gè) d-1 維向量 ,通過(guò)如下方式組合成最終采樣結(jié)果:
設(shè) ,那么:
所以我們主要研究從概率密度正比于 的分布中采樣。
然而,筆者所不理解的是,大多數(shù)涉及到 vMF 分布的論文,都采用了 1994 年的論文《Simulation of the von mises fisher distribution》[3] 提出的基于 beta 分布的拒絕采樣方案,整個(gè)采樣流程還是頗為復(fù)雜的。但現(xiàn)在都 2021 年了,對(duì)于一維分布的采樣,居然還需要拒絕采樣這么低效的方案?
事實(shí)上,對(duì)于任意一維分布 ,設(shè)它的累積概率函數(shù)為 ,那么 就是一個(gè)最方便通用的采樣方案。可能有讀者抗議說(shuō)“累積概率函數(shù)不好算呀”、“它的逆函數(shù)更不好算呀”,但是在用代碼實(shí)現(xiàn)采樣的時(shí)候,我們壓根就不需要知道 長(zhǎng)啥樣,只要直接數(shù)值計(jì)算就行了,參考實(shí)現(xiàn)如下:
import?numpy?as?npdef?sample_from_pw(size,?kappa,?dims,?epsilon=1e-7):x?=?np.arange(-1?+?epsilon,?1,?epsilon)y?=?kappa?*?x?+?np.log(1?-?x**2)?*?(dims?-?3)?/?2y?=?np.cumsum(np.exp(y?-?y.max()))y?=?y?/?y[-1]return?np.interp(np.random.random(size),?y,?x)這里的實(shí)現(xiàn)中,計(jì)算量最大的是變量 y 的計(jì)算,而一旦計(jì)算好之后,可以緩存下來(lái),之后只需要執(zhí)行最后一步來(lái)完成采樣,其速度是非常快的。這樣再怎么看,也比從 beta 分布中拒絕采樣要簡(jiǎn)單方便吧。順便說(shuō),實(shí)現(xiàn)上這里還用到了一個(gè)技巧,即先計(jì)算對(duì)數(shù)值,然后減去最大值,最后才算指數(shù),這樣可以防止溢出,哪怕 成千上萬(wàn),也可以成功計(jì)算。
3.3 一般情形
現(xiàn)在我們已經(jīng)實(shí)現(xiàn)了從 的 vMF 分布中采樣了,我們可以將采樣結(jié)果分解為:
同樣由于各向同性的原因,對(duì)于一般的 ,采樣結(jié)果依然具有同樣的形式:
對(duì)于 v 的采樣,關(guān)鍵之處是與 正交,這也不難實(shí)現(xiàn),先從標(biāo)準(zhǔn)正態(tài)分布中采樣一個(gè) d 維向量 z,然后保留與 正交的分量并歸一化即可:
vMF-VAE
至此,我們可謂是已經(jīng)完成了本篇文章最艱難的部分,剩下的構(gòu)建 vMF-VAE 可謂是水到渠成了。vMF-VAE 選用球面上的均勻分布()作為先驗(yàn)分布 ,并將后驗(yàn)分布選取為 vMF 分布:
簡(jiǎn)單起見(jiàn),我們將 設(shè)為超參數(shù)(也可以理解為通過(guò)人工而不是梯度下降來(lái)更新這個(gè)參數(shù)),這樣一來(lái), 的唯一參數(shù)來(lái)源就是 了。此時(shí)我們可以計(jì)算 KL 散度項(xiàng):
前面我們已經(jīng)討論過(guò),vMF 分布的均值方向跟 一致,模長(zhǎng)則只依賴(lài)于 d 和 ,所以代入上式后我們可以知道 KL 散度項(xiàng)只依賴(lài)于 d 和 ,當(dāng)這兩個(gè)參數(shù)被選定之后,那么它就是一個(gè)常數(shù)(根據(jù) KL 散度的性質(zhì),當(dāng) 時(shí),它必然大于 0),絕對(duì)不會(huì)出現(xiàn) KL 散度消失現(xiàn)象了。
那么現(xiàn)在就剩下重構(gòu)項(xiàng)了,我們需要用“重參數(shù)(Reparameterization)”來(lái)完成采樣并保留梯度,在前面我們已經(jīng)研究了vMF的采樣過(guò)程,所以也不難實(shí)現(xiàn),綜合的流程為:
這里的重構(gòu) loss 以 MSE 為例,如果是句子重構(gòu),那么換用交叉熵就好。其中 就是編碼器,而 就是解碼器,由于 KL 散度項(xiàng)為常數(shù),對(duì)優(yōu)化沒(méi)影響,所 以vMF-VAE 相比于普通的自編碼器,只是多了一項(xiàng)稍微有點(diǎn)復(fù)雜的重參數(shù)操作(以及人工調(diào)整 )而已,相比基于高斯分布的標(biāo)準(zhǔn) VAE 可謂簡(jiǎn)化了不少了。
此外,從該流程我們也可以看出,除了“簡(jiǎn)單起見(jiàn)”之外,不將 設(shè)為可訓(xùn)練還有一個(gè)主要原因,那就是 關(guān)系到 w 的采樣,而在w的采樣過(guò)程中要保留 的梯度是比較困難的。
參考實(shí)現(xiàn)
vMF-VAE 的實(shí)現(xiàn)難度主要是重參數(shù)部分,也就還是從 vMF 分布中采樣,而關(guān)鍵之處就是 w 的采樣。前面我們已經(jīng)給出了 w 的采樣的 numpy 實(shí)現(xiàn),但是在 tf 中未見(jiàn)類(lèi)似 np.interp 的函數(shù),因此不容易轉(zhuǎn)換為純 tf 的實(shí)現(xiàn)。當(dāng)然,如果是torch或者 tf2 這種動(dòng)態(tài)圖框架,直接跟 numpy 的代碼混合使用也無(wú)妨,但這里還是想構(gòu)造一種比較通用的方案。
其實(shí)也不難,由于 w 只是一個(gè)一維變量,每步訓(xùn)練只需要用到 batch_size 個(gè)采樣結(jié)果,所以我們完全可以事先用 numpy 函數(shù)采樣好足夠多(幾十萬(wàn))個(gè) w 存好,然后訓(xùn)練的時(shí)候直接從這批采樣好的結(jié)果隨機(jī)抽就行了,參考實(shí)現(xiàn)如下:
def?sampling(mu):"""vMF分布重參數(shù)操作"""dims?=?K.int_shape(mu)[-1]#?預(yù)先計(jì)算一批wepsilon?=?1e-7x?=?np.arange(-1?+?epsilon,?1,?epsilon)y?=?kappa?*?x?+?np.log(1?-?x**2)?*?(dims?-?3)?/?2y?=?np.cumsum(np.exp(y?-?y.max()))y?=?y?/?y[-1]W?=?K.constant(np.interp(np.random.random(10**6),?y,?x))#?實(shí)時(shí)采樣widxs?=?K.random_uniform(K.shape(mu[:,?:1]),?0,?10**6,?dtype='int32')w?=?K.gather(W,?idxs)#?實(shí)時(shí)采樣zeps?=?K.random_normal(K.shape(mu))nu?=?eps?-?K.sum(eps?*?mu,?axis=1,?keepdims=True)?*?munu?=?K.l2_normalize(nu)return?w?*?mu?+?(1?-?w**2)**0.5?*?nu一個(gè)基于 MNIST 的完整例子可見(jiàn):
https://github.com/bojone/vae/blob/master/vae_vmf_keras.py
至于 vMF-VAE 用于 NLP 的例子,我們?nèi)蘸笥袡C(jī)會(huì)再分享。本文主要還是以理論介紹和簡(jiǎn)單演示為主。
文章小結(jié)
本文介紹了基于 vMF 分布的 VAE 實(shí)現(xiàn),其主要難度在于 vMF 分布的采樣。總的來(lái)說(shuō),vMF 分布建立在余弦相似度度量之上,在某些方面的性質(zhì)更符合我們的直觀認(rèn)知,將其用于 VAE 中,能夠使得 KL 散度項(xiàng)為一個(gè)常數(shù),從而防止了 KL 散度消失現(xiàn)象,并且簡(jiǎn)化了 VAE 結(jié)構(gòu)。
參考文獻(xiàn)
[1] https://arxiv.org/abs/1804.00891
[2] https://arxiv.org/abs/1808.10805
[3] https://www.tandfonline.com/doi/abs/10.1080/03610919408813161
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類(lèi)優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
?????來(lái)稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來(lái)稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專(zhuān)欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結(jié)
以上是生活随笔為你收集整理的变分自编码器:球面上的VAE(vMF-VAE)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 清华大学刘知远教授新作,图神经网络最佳解
- 下一篇: 图像识别最新赛事!总奖金31万,一起组队