变分自编码器系列:VAE + BN = 更好的VAE
?PaperWeekly 原創(chuàng) ·?作者|蘇劍林
單位|追一科技
研究方向|NLP、神經(jīng)網(wǎng)絡(luò)
本文我們繼續(xù)之前的變分自編碼器系列,分析一下如何防止 NLP 中的 VAE 模型出現(xiàn)“KL 散度消失(KL Vanishing)”現(xiàn)象。本文受到參考文獻是 ACL 2020 的論文 A Batch Normalized Inference Network Keeps the KL Vanishing Away [1]?的啟發(fā),并自行做了進一步的完善。
值得一提的是,本文最后得到的方案還是頗為簡潔的——只需往編碼輸出加入BN(Batch Normalization),然后加個簡單的 scale——但確實很有效,因此值得正在研究相關(guān)問題的讀者一試。同時,相關(guān)結(jié)論也適用于一般的 VAE 模型(包括 CV 的),如果按照筆者的看法,它甚至可以作為 VAE 模型的“標配”。
最后,要提醒讀者這算是一篇 VAE 的進階論文,所以請讀者對 VAE 有一定了解后再來閱讀本文。
VAE簡單回顧
這里我們簡單回顧一下 VAE 模型,并且討論一下 VAE 在 NLP 中所遇到的困難。關(guān)于 VAE 的更詳細介紹,請讀者參考筆者的舊作變分自編碼器 VAE:原來是這么一回事、再談變分自編碼器 VAE:從貝葉斯觀點出發(fā)等。
1.1 VAE的訓練流程
VAE 的訓練流程大概可以圖示為:
▲ VAE訓練流程圖示
寫成公式就是:
其中第一項就是重構(gòu)項, 是通過重參數(shù)來實現(xiàn);第二項則稱為 KL 散度項,這是它跟普通自編碼器的顯式差別,如果沒有這一項,那么基本上退化為常規(guī)的 AE。更詳細的符號含義可以參考再談變分自編碼器 VAE:從貝葉斯觀點出發(fā)。
1.2 NLP中的VAE
在 NLP 中,句子被編碼為離散的整數(shù) ID,所以 q(x|z) 是一個離散型分布,可以用萬能的“條件語言模型”來實現(xiàn),因此理論上 q(x|z) 可以精確地擬合生成分布,問題就出在 q(x|z) 太強了,訓練時重參數(shù)操作會來噪聲,噪聲一大,z 的利用就變得困難起來,所以它干脆不要 z 了,退化為無條件語言模型(依然很強), 則隨之下降到 0,這就出現(xiàn)了?KL 散度消失現(xiàn)象。
這種情況下的 VAE 模型并沒有什么價值:KL 散度為 0 說明編碼器輸出的是 0 向量,而解碼器則是一個普通的語言模型。而我們使用 VAE 通常來說是看中了它無監(jiān)督構(gòu)建編碼向量的能力,所以要應用 VAE 的話還是得解決 KL 散度消失問題。
事實上從 2016 開始,有不少工作在做這個問題,相應地也提出了很多方案,比如退火策略、更換先驗分布等,讀者 Google 一下“KL Vanishing”就可以找到很多文獻了,這里不一一溯源。
1.3 BN的巧與秒
本文的方案則是直接針對 KL 散度項入手,簡單有效而且沒什么超參數(shù)。其思想很簡單:
KL 散度消失不就是 KL 散度項變成 0 嗎?我調(diào)整一下編碼器輸出,讓 KL 散度有一個大于零的下界,這樣它不就肯定不會消失了嗎?
這個簡單的思想的直接結(jié)果就是:在 后面加入 BN 層,如圖:
▲ 往VAE里加入BN
1.4 推導過程簡述
為什么會跟 BN 聯(lián)系起來呢?我們來看 KL 散度項的形式:
上式是采樣了 b 個樣本進行計算的結(jié)果,而編碼向量的維度則是 d 維。由于我們總是有 ,所以 ,因此:
留意到括號里邊的量,其實它就是 在 batch 內(nèi)的二階矩,如果我們往 加入 BN 層,那么大體上可以保證 的均值為 ,方差為 ( 是 BN 里邊的可訓練參數(shù)),這時候:
所以只要控制好 (主要是固定 為某個常數(shù)),就可以讓 KL 散度項有個正的下界,因此就不會出現(xiàn) KL 散度消失現(xiàn)象了。這樣一來,KL 散度消失現(xiàn)象跟 BN 就被巧妙地聯(lián)系起來了,通過 BN 來“杜絕”了 KL 散度消失的可能性。
1.5 為什么不是LN?
善于推導的讀者可能會想到,按照上述思路,如果只是為了讓 KL 散度項有個正的下界,其實 LN(Layer Normalization)也可以,也就是在式(3)中按 j 那一維歸一化。
那為什么用BN而不是LN呢?
這個問題的答案也是 BN 的巧妙之處。直觀來理解,KL 散度消失是因為 的噪聲比較大,解碼器無法很好地辨別出 z 中的非噪聲成分,所以干脆棄之不用。
而當給 加上 BN 后,相當于適當?shù)乩_了不同樣本的 z 的距離,使得哪怕 z 帶了噪聲,區(qū)分起來也容易一些,所以這時候解碼器樂意用 z 的信息,因此能緩解這個問題;相比之下,LN 是在樣本內(nèi)進的行歸一化,沒有拉開樣本間差距的作用,所以 LN 的效果不會有 BN 那么好。
進一步的結(jié)果
事實上,原論文的推導到上面基本上就結(jié)束了,剩下的都是實驗部分,包括通過實驗來確定 的值。然而,筆者認為目前為止的結(jié)論還有一些美中不足的地方,比如沒有提供關(guān)于加入 BN 的更深刻理解,倒更像是一個工程的技巧,又比如只是 加上了 BN, 沒有加上,未免有些不對稱之感。
經(jīng)過筆者的推導,發(fā)現(xiàn)上面的結(jié)論可以進一步完善。
2.1 聯(lián)系到先驗分布
對于 VAE 來說,它希望訓練好后的模型的隱變量分布為先驗分布 ,而后驗分布則是 ,所以 VAE 希望下式成立:
兩邊乘以 z,并對 z 積分,得到:
兩邊乘以 ,并對 z 積分,得到:
如果往 都加入 BN,那么我們就有:
所以現(xiàn)在我們知道 一定是 0,而如果我們也固定 ,那么我們就有約束關(guān)系:
2.2 參考的實現(xiàn)方案
經(jīng)過這樣的推導,我們發(fā)現(xiàn)可以往 都加入 BN,并且可以固定 ,但此時需要滿足約束(9)。
要注意的是,這部分討論還僅僅是對 VAE 的一般分析,并沒有涉及到 KL 散度消失問題,哪怕這些條件都滿足了,也無法保證 KL 項不趨于 0。結(jié)合式(4)我們可以知道,保證 KL 散度不消失的關(guān)鍵是確保 ,所以,筆者提出的最終策略是:
其中 是一個常數(shù),筆者在自己的實驗中取了 ,而 是可訓練參數(shù),上式利用了恒等式 。
關(guān)鍵代碼參考(Keras):
class?Scaler(Layer):"""特殊的scale層"""def?__init__(self,?tau=0.5,?**kwargs):super(Scaler,?self).__init__(**kwargs)self.tau?=?taudef?build(self,?input_shape):super(Scaler,?self).build(input_shape)self.scale?=?self.add_weight(name='scale',?shape=(input_shape[-1],),?initializer='zeros')def?call(self,?inputs,?mode='positive'):if?mode?==?'positive':scale?=?self.tau?+?(1?-?self.tau)?*?K.sigmoid(self.scale)else:scale?=?(1?-?self.tau)?*?K.sigmoid(-self.scale)return?inputs?*?K.sqrt(scale)def?get_config(self):config?=?{'tau':?self.tau}base_config?=?super(Scaler,?self).get_config()return?dict(list(base_config.items())?+?list(config.items()))def?sampling(inputs):"""重參數(shù)采樣"""z_mean,?z_std?=?inputsnoise?=?K.random_normal(shape=K.shape(z_mean))return?z_mean?+?z_std?*?noisee_outputs??#?假設(shè)e_outputs是編碼器的輸出向量 scaler?=?Scaler() z_mean?=?Dense(hidden_dims)(e_outputs) z_mean?=?BatchNormalization(scale=False,?center=False,?epsilon=1e-8)(z_mean) z_mean?=?scaler(z_mean,?mode='positive') z_std?=?Dense(hidden_dims)(e_outputs) z_std?=?BatchNormalization(scale=False,?center=False,?epsilon=1e-8)(z_std) z_std?=?scaler(z_std,?mode='negative') z?=?Lambda(sampling,?name='Sampling')([z_mean,?z_std])文章內(nèi)容小結(jié)
本文簡單分析了 VAE 在 NLP 中的 KL 散度消失現(xiàn)象,并介紹了通過 BN 層來防止 KL 散度消失、穩(wěn)定訓練流程的方法。這是一種簡潔有效的方案,不單單是原論文,筆者私下也做了簡單的實驗,結(jié)果確實也表明了它的有效性,值得各位讀者試用。因為其推導具有一般性,所以甚至任意場景(比如 CV)中的 VAE 模型都可以嘗試一下。
參考鏈接
[1] https://arxiv.org/abs/2004.12585
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學習心得或技術(shù)干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創(chuàng)作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發(fā),請在投稿時提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認每篇文章都是首發(fā),均會添加“原創(chuàng)”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發(fā)送?
? 請留下即時聯(lián)系方式(微信或手機),以便我們在編輯發(fā)布時和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術(shù)平臺。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
?
總結(jié)
以上是生活随笔為你收集整理的变分自编码器系列:VAE + BN = 更好的VAE的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 当“塑料小人”站上“大舞台”
- 下一篇: 《永劫无间》手游:目前还没有完成针对鸿蒙