【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务
大家好,歡迎來(lái)到專欄《百戰(zhàn)GAN》,我們?cè)诠娞?hào)已經(jīng)輸出了非常多的GAN相關(guān)的理論,這一次我們開設(shè)《百戰(zhàn)GAN》專欄,在這個(gè)專欄里,我們會(huì)進(jìn)行算法的核心思想講解,代碼的詳解,模型的訓(xùn)練等內(nèi)容。
作者&編輯 | 言有三
本文資源與生成結(jié)果展示
本文篇幅:5000字
背景要求:會(huì)使用Python,Tensorflow或者Pytorch
附帶資料:項(xiàng)目推薦,版本包括Pytorch+Tensorflow
同步平臺(tái):有三AI知識(shí)星球(一周內(nèi))
1 項(xiàng)目背景
生成對(duì)抗網(wǎng)絡(luò)如今在計(jì)算機(jī)視覺的很多領(lǐng)域中都被廣泛應(yīng)用,需要每一個(gè)學(xué)習(xí)深度學(xué)習(xí)相關(guān)技術(shù)的算法人員掌握,我們公眾號(hào)和知識(shí)星球講述了非常多的理論知識(shí),在這個(gè)《百戰(zhàn)GAN》專欄中,我們會(huì)配合各類實(shí)戰(zhàn)案例來(lái)幫助大家進(jìn)行提升,本次項(xiàng)目開發(fā)需要以下環(huán)境:
(1) Linux系統(tǒng)或者windows系統(tǒng),使用Linux效率更高。
(2)?安裝好的Tensorflow,CPU或者GPU訓(xùn)練都可以。
2 原理簡(jiǎn)介
今天我們要實(shí)踐的模型是DCGAN和CGAN,DCGAN是第一個(gè)全卷積GAN,麻雀雖小,五臟俱全,最適合新人實(shí)踐。
DCGAN的生成器和判別器都采用了4層的網(wǎng)絡(luò)結(jié)構(gòu)。生成器網(wǎng)絡(luò)結(jié)構(gòu)如上圖所示,輸入為1×100的向量,然后經(jīng)過一個(gè)全連接層學(xué)習(xí),reshape為4×4×1024的張量,再經(jīng)過4個(gè)上采樣的反卷積網(wǎng)絡(luò)層,生成64×64的圖,各層的配置如下:
判別器輸入64×64大小的圖,經(jīng)過4次卷積,分辨率降低為4×4的大小,每一個(gè)卷積層的配置如下:
DCGAN并不能控制生成圖片的類別,條件GAN(CGAN)則使用了條件控制變量作為輸入,是幾乎后續(xù)所有性能強(qiáng)大的GAN的基礎(chǔ)。網(wǎng)絡(luò)結(jié)構(gòu)如下,其中的y就是條件變量。
對(duì)于生成器來(lái)說,輸入包括z和y,兩者會(huì)進(jìn)行拼接后作為輸入。對(duì)于判別器來(lái)說,輸入包括了x和y,兩者會(huì)進(jìn)行拼接后作為輸入,當(dāng)然為了和z以及x進(jìn)行拼接,y需要做一些維度變換,即reshape操作。
關(guān)于它們的理論更加詳細(xì)的講解,大家可以移步有三AI知識(shí)星球,或者自行閱讀論文。
3 模型訓(xùn)練
接下來(lái)我們進(jìn)行實(shí)踐,選擇tensorflow框架,下面詳解具體的工程代碼,主要包括:
(1) 生成器和判別器模型的定義。
(2) 損失和優(yōu)化目標(biāo)的定義。
3.1 DCGAN類定義
首先我們需要定義一個(gè)類,設(shè)計(jì)好輸入輸出,__init__函數(shù)如下:
# 模型定義
class DCGAN(object):
? ? def __init__(self, sess, input_height=108, input_width=108, crop=True,
???????? batch_size=64, sample_num = 64, output_height=64, output_width=64,
???????? y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
???????? gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
???????? max_to_keep=1,
???????? input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
其中參數(shù)解釋如下:sess表示TensorFlow session,batch_size即批處理大小;z_dim是噪聲的維度,默認(rèn)為100;y_dim是一個(gè)可選的條件變量,比如分類標(biāo)簽,用于CGAN;gf_dim是生成器第一個(gè)卷積層的通道數(shù);df_dim是判別器第一個(gè)卷積層的通道數(shù);gfc_dim是生成器全連接層維度;dfc_dim是判別器全連接層維度;c_dim是輸入圖像維度,灰度圖為1,彩色圖為3。
從上述代碼可以看出,初始化函數(shù)__init__中配置了訓(xùn)練輸入圖尺寸,批處理大小,輸出圖尺寸,生成器的輸入維度,以及生成器和判別的卷積層和全連接層的若干維度變量。
總結(jié)
以上是生活随笔為你收集整理的【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【通知】3月第三周直播预告,模型精简前沿
- 下一篇: 推荐 | 有三AI生态新的资源干货集中营