如何训练GAN
如何訓(xùn)練GAN,FAIR的研究員Soumith Chintala總結(jié)了訓(xùn)練GAN的16個(gè)技巧,例如輸入的規(guī)范化,修改損失函數(shù),生成器用Adam優(yōu)化,使用Sofy和Noisy標(biāo)簽,等等。這是NIPS 2016的Soumith Chintala作的邀請(qǐng)演講的修改版本。
ICCV 2017 slides:https://github.com/soumith/talks/blob/master/2017-ICCV_Venice/How_To_Train_a_GAN.pdf
NIPS2016:https://github.com/soumith/ganhacks
訓(xùn)練GAN的16個(gè)trick
# 1:規(guī)范化輸入
-
將輸入圖像規(guī)范化為-1到1之間
-
生成器最后一層的輸出使用tanh函數(shù)(或其他bounds normalization)
#2:修改損失函數(shù)(經(jīng)典GAN)
-
在GAN論文里人們通常用 min (log 1-D) 這個(gè)損失函數(shù)來優(yōu)化G,但在實(shí)際訓(xùn)練的時(shí)候可以用max log D
-因?yàn)榈谝粋€(gè)公式早期有梯度消失的問題
- Goodfellow et. al (2014)
-
在實(shí)踐中:訓(xùn)練G時(shí)使用反轉(zhuǎn)標(biāo)簽?zāi)芄ぷ鞯煤芎?#xff0c;即:real = fake, fake = real
一些GAN變體
【TensorFlow】https://github.com/hwalsuklee/tensorflow-generative-model-collections
【Pytorch】https://github.com/znxlwm/pytorch-generative-model-collections
#3:使用一個(gè)具有球形結(jié)構(gòu)的噪聲z
-
在做插值(interpolation)時(shí),在大圓(great circle)上進(jìn)行
-
Tom White的論文“Sampling Generative Networks”
- https://arxiv.org/abs/1609.04468
#4: BatchNorm
-
一個(gè)mini-batch里面必須保證只有Real樣本或者Fake樣本,不要把它們混起來訓(xùn)練
-
如果不能用batchnorm,可以用instance norm
#5:避免稀疏梯度:ReLU, MaxPool
-
GAN的穩(wěn)定性會(huì)因?yàn)橐肓讼∈杼荻仁艿接绊?/p>
-
LeakyReLU很好(對(duì)于G和D)
-
對(duì)于下采樣,使用:Average Pooling,Conv2d + stride
-
對(duì)于上采樣,使用:PixelShuffle, ConvTranspose2d + stride
-PixelShuffle 論文:https://arxiv.org/abs/1609.05158
#6:使用Soft和Noisy標(biāo)簽
-
Label平滑,也就是說,如果有兩個(gè)目標(biāo)label:Real=1 和 Fake=0,那么對(duì)于每個(gè)新樣本,如果是real,那么把label替換為0.7~1.2之間的隨機(jī)值;如果樣本是fake,那么把label替換為0.0~0.3之間的隨機(jī)值。
-
訓(xùn)練D時(shí),有時(shí)候可以使這些label是噪聲:偶爾翻轉(zhuǎn)label
- Salimans et. al. 2016
#7:架構(gòu):DCGANs / Hybrids
-
能用DCGAN就用DCGAN,
-
如果用不了DCGAN而且沒有穩(wěn)定的模型,可以使用混合模型:KL + GAN 或 VAE + GAN
-
WGAN-gp的ResNet也很好(但非常慢)
- https://github.com/igul222/improved_wgan_training
-
width比depth更重要
#8:借用RL的訓(xùn)練技巧
-
Experience replay
-
對(duì)于deep deterministic policy gradients(DDPG)有效的技巧
-
參考Pfau & Vinyals (2016)的論文
#9:優(yōu)化器:ADAM
-
優(yōu)化器用Adam(Radford et. al. 2015)
-
或者對(duì)D用SGD,G用Adam
#10:使用 Gradient Penalty
-
使梯度的norm規(guī)范化
-
對(duì)于為什么這一點(diǎn)有效,有多個(gè)理論(WGAN-GP, DRAGAN, 通過規(guī)范化使GAN穩(wěn)定)
#11:不要通過loss statistics去balance G與D的訓(xùn)練過程(經(jīng)典GAN)
#12:如果你有類別標(biāo)簽,請(qǐng)使用它們
-
如果還有可用的類別標(biāo)簽,在訓(xùn)練D判別真?zhèn)蔚耐瑫r(shí)對(duì)樣本進(jìn)行分類
#13:給輸入增加噪聲,隨時(shí)間衰減
-
給D的輸入增加一些人工噪聲(Arjovsky et. al., Huszar, 2016)
-
給G的每一層增加一些高斯噪聲(Zhao et. al. EBGAN)
#14:多訓(xùn)練判別器D
-
特別是在加噪聲的時(shí)候
#15:避開離散空間
-
將生成結(jié)果作為一個(gè)連續(xù)預(yù)測(cè)
#16:離散變量
-
使用一個(gè)嵌入層
-
給圖像增加額外通道
-
保持嵌入的維度低和上采樣以匹配圖像通道的大小
總結(jié):
-
GAN模型的穩(wěn)定性在提升
-
理論研究有所進(jìn)展
-
技巧只是權(quán)宜之計(jì)
時(shí)間線——GAN模型的穩(wěn)定性
PPT下載:https://github.com/soumith/talks/blob/master/2017-ICCV_Venice/How_To_Train_a_GAN.pdf
參考:https://github.com/soumith/ganhacks
總結(jié)
- 上一篇: SBO业务单据类型(总结)
- 下一篇: HDU-2079 选课时间(题目已修改,