GAN半监督学习
概述
GAN的發(fā)明者Ian Goodfellow2016年在Open AI任職期間發(fā)表了這篇論文,其中提到了GAN用于半監(jiān)督學(xué)習(xí)(semi supervised)的方法。稱為SSGAN。?
作者給出了Theano+Lasagne實(shí)現(xiàn)。本文結(jié)合源碼對(duì)這種方法的推導(dǎo)和實(shí)現(xiàn)進(jìn)行講解。1
半監(jiān)督學(xué)習(xí)
考慮一個(gè)分類問題。?
如果訓(xùn)練集中大部分樣本沒有標(biāo)記類別,只有少部分樣本有標(biāo)記。則需要用半監(jiān)督學(xué)習(xí)(semi-supervised)方法來訓(xùn)練一個(gè)分類器。
wiki上的這張圖很好地說明了無標(biāo)記樣本在半監(jiān)督學(xué)習(xí)中發(fā)揮作用:?
如果只考慮有標(biāo)記樣本(黑白點(diǎn)),純粹使用監(jiān)督學(xué)習(xí)。則得到垂直的分類面。?
考慮了無標(biāo)記樣本(灰色點(diǎn))之后,我們對(duì)樣本的整體分布有了進(jìn)一步認(rèn)識(shí),能夠得到新的、更準(zhǔn)確的分類面。
核心理念
在半監(jiān)督學(xué)習(xí)中運(yùn)用GAN的邏輯如下。
- 無標(biāo)記樣本沒有類別信息,無法訓(xùn)練分類器;
- 引入GAN后,其中生成器(Generator)可以從隨機(jī)信號(hào)生成偽樣本;
- 相比之下,原有的無標(biāo)記樣本擁有了人造類別:真。可以和偽樣本一起訓(xùn)練分類器。?
舉個(gè)通俗的例子:就算沒人教認(rèn)字,多練練分辨“是不是字”也對(duì)認(rèn)字有好處。有粗糙的反饋,也比沒有反饋強(qiáng)。
原理
框架
GAN中的兩個(gè)核心模塊是生成器(Generator)和鑒別器(Discriminator)。這里用分類器(Classifier)代替了鑒別器。?
訓(xùn)練集中包含有標(biāo)簽樣本xlxl和無標(biāo)簽樣本xuxu。?
生成器從隨機(jī)噪聲生成偽樣本IfIf。?
分類器接受樣本II,對(duì)于KK類分類問題,輸出K+1K+1維估計(jì)ll,再經(jīng)過softmax函數(shù)得到概率pp:其前KK維對(duì)應(yīng)原有KK個(gè)類,最后一維對(duì)應(yīng)“偽樣本”類。?
pp的最大值位置對(duì)應(yīng)為估計(jì)標(biāo)簽yy。
三種誤差
整個(gè)系統(tǒng)涉及三種誤差。
對(duì)于訓(xùn)練集中的有標(biāo)簽樣本,考察估計(jì)的標(biāo)簽是否正確。即,計(jì)算分類為相應(yīng)的概率:?
對(duì)于訓(xùn)練集中的無標(biāo)簽樣本,考察是否估計(jì)為“真”。即,計(jì)算不估計(jì)為K+1K+1類的概率:?
對(duì)于生成器產(chǎn)生的偽樣本,考察是否估計(jì)為“偽”。即,計(jì)算估計(jì)為K+1K+1類的概率:?
推導(dǎo)
考慮softmax函數(shù)的一個(gè)特性:?
即,如果輸入各維減去同一個(gè)數(shù),softmax結(jié)果不變。?
于是,可以令 l→l?lK+1l→l?lK+1 ,有 lK+1=0lK+1=0 , p=softmax(l)p=softmax(l) 保持不變。
期望號(hào)略去不寫,利用explK+1=1,exp?lK+1=1,后兩種代價(jià)變?yōu)?#xff1a;?
上述推導(dǎo)可以讓我們省去lK+1lK+1,讓分類器仍然輸出K維的估計(jì)ll。
對(duì)于第一個(gè)代價(jià),由于分類器輸入必定來自前K類,所以可以直接使用ll的前K維:?
引入兩個(gè)函數(shù),使得書寫更為簡(jiǎn)潔:
LSE(x)=ln[∑j=1expxj]LSE(x)=ln?[∑j=1exp?xj]softplus(x)=ln(1+expx)softplus(x)=ln?(1+exp?x)三個(gè)誤差:?
優(yōu)化目標(biāo)
對(duì)于分類器來說,希望上述誤差盡量小。引入權(quán)重ww,得到分類器優(yōu)化目標(biāo):?
對(duì)于生成器來說,希望其輸出的偽樣本能夠騙過分類器。生成器優(yōu)化目標(biāo)與分類器的第三項(xiàng)相反:?
實(shí)驗(yàn)
本文的實(shí)驗(yàn)包含三個(gè)圖像分類問題。分類器接受圖像xx,輸出KK類分類結(jié)果ll。生成器從均勻分布的噪聲zz生成一張圖像xx。
MNIST
10分類問題,圖像為28*28灰度。
生成器是一個(gè)3層線性網(wǎng)絡(luò):?
分類器是一個(gè)6層線性網(wǎng)絡(luò):?
訓(xùn)練樣本60K個(gè),測(cè)試樣本10K個(gè)。?
選擇不同數(shù)量的訓(xùn)練樣本給予標(biāo)記,考察測(cè)試樣本中錯(cuò)誤個(gè)數(shù)。使用不同隨機(jī)數(shù)種子重復(fù)10次:
| 占比 | 0.033% | 0.083% | 0.17% | 0.33% |
| 錯(cuò)誤個(gè)數(shù) | 1677±452 | 221±136 | 93±6.5 | 90±4.2 |
Cifar10
10分類問題,圖像為32*32彩色。
生成器是一個(gè)4層反卷積網(wǎng)絡(luò):?
分類器是一個(gè)9層卷積網(wǎng)絡(luò):?
訓(xùn)練樣本50K個(gè),測(cè)試樣本10K個(gè)。?
選擇不同數(shù)量的訓(xùn)練樣本給予標(biāo)記,考察測(cè)試樣本中錯(cuò)誤個(gè)數(shù)。使用不同的測(cè)試/訓(xùn)練分割重復(fù)10次:
| 占比 | 2% | 4% | 8% | 16% |
| 錯(cuò)誤個(gè)數(shù) | 21.83±2.01 | 19.61±2.09 | 18.63±2.32 | 17.72±1.82 |
SVHN
10分類問題,圖像為32*32彩色。
生成器(上)以及分類器(下)和CIFAR10的結(jié)構(gòu)非常類似。?
訓(xùn)練樣本73K,測(cè)試樣本26K。?
選擇不同數(shù)量的訓(xùn)練樣本給予標(biāo)記,考察測(cè)試樣本中錯(cuò)誤個(gè)數(shù)。使用不同的測(cè)試/訓(xùn)練分割重復(fù)10次:
| 占比 | 0.68% | 1.4% | 2.7% |
| 錯(cuò)誤個(gè)數(shù) | 18.84±4.8 | 8.11±1.3 | 6.16±0.58 |
總結(jié)