巧断梯度:单个loss实现GAN模型(附开源代码)
作者丨蘇劍林
單位丨廣州火焰信息科技有限公司
研究方向丨NLP,神經網絡
個人主頁丨kexue.fm
我們知道普通的模型都是搭好架構,然后定義好 loss,直接扔給優化器訓練就行了。但是 GAN 不一樣,一般來說它涉及有兩個不同的 loss,這兩個 loss 需要交替優化。
現在主流的方案是判別器和生成器都按照 1:1 的次數交替訓練(各訓練一次,必要時可以給兩者設置不同的學習率,即 TTUR),交替優化就意味我們需要傳入兩次數據(從內存傳到顯存)、執行兩次前向傳播和反向傳播。
如果我們能把這兩步合并起來,作為一步去優化,那么肯定能節省時間的,這也就是 GAN 的同步訓練。
注:本文不是介紹新的 GAN,而是介紹 GAN 的新寫法,這只是一道編程題,不是一道算法題。
如果在TF中
如果是在 TensorFlow 中,實現同步訓練并不困難,因為我們定義好了判別器和生成器的訓練算子了(假設為 D_solver 和 G_solver ),那么直接執行:
就行了。這建立在我們能分別獲取判別器和生成器的參數、能直接操作 sess.run 的基礎上。
更通用的方法?
但是如果是 Keras 呢?Keras 中已經把流程封裝好了,一般來說我們沒法去操作得如此精細。
所以,下面我們介紹一個通用的技巧,只需要定義單一一個 loss,然后扔給優化器,就能夠實現 GAN 的訓練。同時,從這個技巧中,我們還可以學習到如何更加靈活地操作 loss 來控制梯度。
判別器的優化
我們以 GAN 的 hinge loss 為例子,它的形式是:
注意意味著要固定 G,因為 G 本身也是有優化參數的,不固定的話就應該是。
為了固定G,除了“把 G 的參數從優化器中去掉”這個方法之外,我們也可以利用 stop_gradient 去手動固定:
這里:
這樣一來,在式 (2) 中,我們雖然同時放開了 D,G 的權重,但是不斷地優化式 (2),會變的只有 D,而 G 是不會變的,因為我們用的是基于梯度下降的優化器,而 G 的梯度已經被停止了,換句話說,我們可以理解為 G 的梯度被強行設置為 0,所以它的更新量一直都是 0。?
生成器的優化
現在解決了 D 的優化,那么 G 呢? stop_gradient 可以很方便地放我們固定里邊部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的優化是要我們去固定外邊的 D,沒有函數實現它。但不要灰心,我們可以用一個數學技巧進行轉化。?
首先,我們要清楚,我們想要 D(G(z)) 里邊的 G 的梯度,不想要 D 的梯度,如果直接對 D(G(z)) 求梯度,那么同時會得到 D,G 的梯度。如果直接求的梯度呢?只能得到 D 的梯度,因為 G 已經被停止了。那么,重點來了,將這兩個相減,不就得到單純的 G 的梯度了嗎!
現在優化式 (4) ,那么 D 是不會變的,改變的是 G。?
值得一提的是,直接輸出這個式子,結果是恒等于 0,因為兩部分都是一樣的,直接相減自然是 0,但它的梯度不是 0。也就是說,這是一個恒等于 0 的 loss,但是梯度卻不恒等于 0。?
合成單一loss?
好了,現在式 (2) 和式 (4) 都同時放開了 D,G,大家都是 arg min,所以可以將兩步合成一個 loss:
寫出這個 loss,就可以同時完成判別器和生成器的優化了,而不需要交替訓練,但是效果基本上等效于 1:1 的交替訓練。引入 λ 的作用,相當于讓判別器和生成器的學習率之比為 1:λ。
參考代碼:
https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py
文章小結
文章主要介紹了實現 GAN 的一個小技巧,允許我們只寫單個模型、用單個 loss 就實現 GAN 的訓練。它本質上就是用 stop_gradient 來手動控制梯度的技巧,在其他任務上也可能用得到它。
所以,以后我寫 GAN 都用這種寫法了,省力省時。當然,理論上這種寫法需要多耗些顯存,這也算是犧牲空間換時間吧。
點擊以下標題查看作者其他文章:?
變分自編碼器VAE:原來是這么一回事 | 附開源代碼
再談變分自編碼器VAE:從貝葉斯觀點出發
變分自編碼器VAE:這樣做為什么能成?
從變分編碼、信息瓶頸到正態分布:論遺忘的重要性
深度學習中的互信息:無監督提取特征
全新視角:用變分推斷統一理解生成模型
細水長flow之NICE:流模型的基本概念與實現
細水長flow之f-VAEs:Glow與VAEs的聯姻
深度學習中的Lipschitz約束:泛化與生成模型
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢??答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
??來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
?
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點擊 |?閱讀原文?| 查看作者博客
總結
以上是生活随笔為你收集整理的巧断梯度:单个loss实现GAN模型(附开源代码)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: AAAI 2019 论文解读 | 基于区
- 下一篇: Self-Attention GAN 中