训练GAN的技巧
GAN, 作為一種非常厲害的生成模型, 在近年來得到了廣泛的應用. Soumith, PyTorch之父, 畢業于紐約大學的Facebook的VP, 在2015年發明了DCGAN: Deep Convolutional GAN. 它顯式的使用卷積和轉置卷積在判別器和生成器中使用. 他對GAN的理解相對深入, 特地總結了關于訓練GAN的一些技巧和方式, 因為不同于一般任務, 像設置優化器, 計算loss以及初始化模型權重等tips, 這些對于GAN網絡能否收斂可以說至關重要. 現在特此翻譯這篇文章, 以饗讀者.
How to Train a GAN? Tips and tricks to make GANs work
隨著人們對生成對抗網絡(GANs)的研究進一步深入, 繼續提高GAN的基本穩定性是非常重要的一環。我們使用了一系列技巧來訓練它們,使它們保持穩定。
作者:
Soumith Chintala, Emily Denton, Martin Arjovsky, Michael Mathieu.
廢話不多說, 直接上干貨:
1. Normalize the inputs( 規則化輸入數據)
- 將圖像的數值范圍限制在 [-1, 1].
- 將Tanh層作為生成器最后輸出層.
2. A modified loss function(修改Loss函數)
在GAN的論文中, 生成器G的目標是使得目標函數log(1?D)log(1?D)log(1?D)log(1?D)log (1-D)log(1?D)log(1?D)log(1?D)log(1?D)最小, 但是實際寫代碼中, 目標是讓log(D)log(D)log(D)log(D)log(D)log(D)log(D)log(D)log(D)最大. 這是因為前面的式子有梯度消失問題. Goodfellow et. al (2014)
此外, 訓練生成器的時候, 還可以將數據對應的**標簽(label)**進行翻轉: 即real = fake, fake = real來進行訓練. 其目的是增強生成器的泛化能力(通常作為在生成器能力很強的時候fine-tune的策略.)
3. Use a spherical Z(使用球面分布)
通常的GAN中,包括2019年最新的styleGAN,它們的latent vector z都是通過正態分布進行采樣得到的(根據情況,可能是非標準正態分布)。
本文推薦對高斯分布(gaussian distribution)進行采樣而得到 z。
此外, Soumith還指出需要注意以下2點:
- When doing interpolations, do the interpolation via a great circle, rather than a straight line from point A to point B.
- Tom White’s Sampling Generative Networks ref code https://github.com/dribnet/plat has more details.
4. Batch Norm(批歸一化)
- 如果要用BN的話,只能在all-fake或all-real的mini-batch中使用。
- 現在流行使用PixelNorm和InstanceNorm[3](關于其PyTorch1.0.1的實現,在我復現StyleGAN的代碼中有,歡迎參考~)
5. Avoid Sparse Gradients: ReLU, MaxPool(避免稀疏梯度)
如果你使用了ReLU或MaxPooling,那么這樣的GAN通常穩定性較差(由于梯度的稀疏性)。
- LeakyReLU = good (in both G and D) 也是目前幾乎所有GAN的標配。
- 下采樣時,建議使用Average Pooling或者Conv2d + stride的方式。
- 上采樣通常使用PixelShuffle[4]和ConvTranspose2d + stride。
6. Use Soft and Noisy Labels(使用光滑和帶噪聲的標簽)
-
對標簽進行平滑, i.e. if you have two target labels: Real=1 and Fake=0, then for each incoming sample, if it is real, then replace the label with a random number between 0.7 and 1.2, and if it is a fake sample, replace it with 0.0 and 0.3 (for example). Salimans et. al. 2016
-
在訓練鑒別器的時候,偶爾翻轉label,即fake->real, real->fake。
7. DCGAN / Hybrid Models(DCGAN和混合模型)
這里,Soumith開始推銷自己的工作了哈哈,他認為DCGAN在任何場景都能很好的工作。
當然,如果你愿意的話,也可以使用Hybrid的模型,比如 KL + GAN 或 VAE + GAN。
8. Use stability tricks from RL(使用增強學習中提升穩定性的策略)
- 經驗重播(Experience Replay)
- 所有穩定性技巧都適用于深度確定性策略梯度。
- 查閱 Pfau & Vinyals (2016)發表的資料。
9. Use the ADAM Optimizer(使用ADAM優化器)
Soumith認為Adam很吊,一個就夠了。大多數情況,生成器和判別器都用ADAM就可以,或者,你也可以使用SGD來優化判別器。
10. Track failures early(及早發現失敗)
-
① 當判別器的loss一直接近0或者為0的時候,那就說明這次訓練是有問題的,應該及時停掉,檢查模型和超參數的設置。
-
② 檢查梯度的范數,如果超過100,就會出錯。
-
③ 當模型正常訓練時,判別器D的loss方差較小,并且隨著時間的推移而下降,或者方差較大且呈峰值。
-
④ 如果生成器G的loss穩步下降,那可能意味著在用垃圾來迷糊判別器D(馬丁說)。
11. Dont balance loss via statistics (unless you have a good reason to)(不要通過統計數字來平衡損失(除非你有充分的理由))
不要通過設計通過判斷loss是否達到我們預設的閾值來進行觸發訓練。Soumith他們已經試了很多了,不好使。如果,你一定要這么做,那么要有自己的方法論,而不能憑直覺。
while lossD > A:train D while lossG > B:train G12. If you have labels, use them (如果有標簽,那么使用它)
如果你有標簽可用,訓練鑒別器D分類樣本: 輔助GANs。
13. Add noise to inputs, decay over time (對輸入加入噪聲)
這個策略在18,19年Nvidia的大神kerras的論文中都體現的淋漓盡致:
- 對喂入判別器的輸入加入人為噪聲。 (Arjovsky et. al., Huszar, 2016)
http://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
https://openreview.net/forum?id=Hk4_qw5xe - 對生成器G的每層都加入高斯噪聲。 (Zhao et. al. EBGAN)
Improved GANs: OpenAI code also has it (commented out)
14. [notsure] Train discriminator more (sometimes) (使用多個判別器)
- 對數據噪聲比較大的情況適用(ECCV2018 MMAN做Human parsing的一篇文章)。
- 當難以確定生成器和判別器迭代策略的時候。
15. [notsure] Batch Discrimination (批量判別?)
- 將結果混合?
16: Discrete variables in Conditional GANs(CGAN中的離散變量)
- 使用 Embedding 層。
- 為圖像增加額外通道。
- 保持較低的嵌入維數,通過上采樣以匹配圖像通道大小。
17. Use Dropouts in G in both train and test phase (在生成器中使用Dropout)
- 以Dropout的形式模擬提供噪聲 (50%)。
- 無論是訓練還是推理階段,都在生成器G中的某幾層加入Dropout機制[5]。
參考資料
[1] Soumith: How to Train a GAN? Tips and tricks to make GANs work
[2] Tom White: Sampling Generative Networks
[3] 基于PyTorch1.x復現的styleGAN
[4] PixelShuffle
[5] https://arxiv.org/pdf/1611.07004v1.pdf
[6] https://blog.csdn.net/g11d111/article/details/89100833
總結
- 上一篇: 陶哲轩实分析命题 11.10.7
- 下一篇: 【转】.NET NPOI操作Excel常