158行代码!程序员复现DeepMind图像生成神器
最近,谷歌 DeepMInd 發(fā)表論文提出了一個(gè)用于圖像生成的遞歸神經(jīng)網(wǎng)絡(luò),該系統(tǒng)大大提高了 MNIST 上生成模型的質(zhì)量。為更加深入了解 DRAW,本文作者基于 Eric Jang 用 158 行 Python 代碼實(shí)現(xiàn)該系統(tǒng)的思路,詳細(xì)闡述了 DRAW 的概念、架構(gòu)和優(yōu)勢(shì)等。
遞歸神經(jīng)網(wǎng)絡(luò)是一種用于圖像生成的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)。Draw Networks 結(jié)合了一種新的空間注意機(jī)制,該機(jī)制模擬了人眼的中心位置,采用了一個(gè)順序變化的自動(dòng)編碼框架,使之對(duì)復(fù)雜圖像進(jìn)行迭代構(gòu)造。
該系統(tǒng)大大提高了 MNIST 上生成模型的質(zhì)量,特別是當(dāng)對(duì)街景房屋編號(hào)數(shù)據(jù)集進(jìn)行訓(xùn)練時(shí),肉眼竟然無(wú)法將它生成的圖像與真實(shí)數(shù)據(jù)區(qū)別開(kāi)來(lái)。
Draw 體系結(jié)構(gòu)的核心是一對(duì)遞歸神經(jīng)網(wǎng)絡(luò):一個(gè)是壓縮用于訓(xùn)練的真實(shí)圖像的編碼器,另一個(gè)是在接收到代碼后重建圖像的解碼器。這一組合系統(tǒng)采用隨機(jī)梯度下降的端到端訓(xùn)練,損失函數(shù)的最大值變分主要取決于對(duì)數(shù)似然函數(shù)的數(shù)據(jù)。
Draw 網(wǎng)絡(luò)類(lèi)似于其他變分自動(dòng)編碼器,它包含一個(gè)編碼器網(wǎng)絡(luò),該編碼器網(wǎng)絡(luò)決定著潛在代碼上的 distribution(潛在代碼主要捕獲有關(guān)輸入數(shù)據(jù)的顯著信息),解碼器網(wǎng)絡(luò)接收來(lái)自 code distribution 的樣本,并利用它們來(lái)調(diào)節(jié)其自身圖像的 distribution 。
DRAW 與其他自動(dòng)解碼器的三大區(qū)別
編碼器和解碼器都是 DRAW 中的遞歸網(wǎng)絡(luò),解碼器的輸出依次添加到 distribution 中以生成數(shù)據(jù),而不是一步一步地生成 distribution 。動(dòng)態(tài)更新的注意機(jī)制用于限制由編碼器負(fù)責(zé)的輸入?yún)^(qū)域和由解碼器更新的輸出區(qū)域 。簡(jiǎn)單地說(shuō),這一網(wǎng)絡(luò)在每個(gè) time-step 都能決定“讀到哪里”和“寫(xiě)到哪里”以及“寫(xiě)什么”。
左:傳統(tǒng)變分自動(dòng)編碼器
在生成過(guò)程中,從先前的 P(z)中提取一個(gè)樣本 z ,并通過(guò)前饋?zhàn)g碼器網(wǎng)絡(luò)來(lái)計(jì)算給定樣本的輸入 P(x_z)的概率。
在推理過(guò)程中,輸入 x 被傳遞到編碼器網(wǎng)絡(luò),在潛在變量上產(chǎn)生一個(gè)近似的后驗(yàn) Q(z|x) 。在訓(xùn)練過(guò)程中,從 Q(z|x) 中抽取 z,然后用它計(jì)算總描述長(zhǎng)度 KL ( Q (Z|x)∣∣ P(Z) log(P(x|z)),該長(zhǎng)度隨隨機(jī)梯度的下降(https://en.wikipedia.org/wiki/Stochastic_gradient_descent)而減小至最小值。
右:DRAW網(wǎng)絡(luò)
在每一個(gè)步驟中,都會(huì)將先前 P(z)中的一個(gè)樣本 z_t 傳遞給遞歸解碼器網(wǎng)絡(luò),該網(wǎng)絡(luò)隨后會(huì)修改 canvas matrix 的一部分。最后一個(gè) canvas matrix cT 用于計(jì)算 P(x|z_1:t)。
在推理過(guò)程中,每個(gè) time-step 都會(huì)讀取輸入,并將結(jié)果傳遞給編碼器 RNN,然后從上一 time-step 中的 RNN 指定讀取位置,編碼器 RNN 的輸出用于計(jì)算該 time-step 的潛在變量的近似后驗(yàn)值。
損失函數(shù)
最后一個(gè) canvas matrix cT 用于確定輸入數(shù)據(jù)的模型 D(X | cT) 的參數(shù)。如果輸入是二進(jìn)制的,D 的自然選擇呈伯努利分布,means 由 σ(cT) 給出。重建損失 Lx 定義為 D 下 x 的負(fù)對(duì)數(shù)概率:
The latent loss 潛在distributions序列
的潛在損失被定義為源自
的潛在先驗(yàn) P(Z_t)的簡(jiǎn)要 KL散度。
鑒于這一損失取決于由
繪制的潛在樣本 z_t ,因此其反過(guò)來(lái)又決定了輸入 x。如果潛在 distribution是一個(gè)這樣的 diagonal Gaussian ,P(Z_t) 便是一個(gè)均值為 0,且具有標(biāo)準(zhǔn)離差的標(biāo)準(zhǔn) Gaussian,這種情況下方程則變?yōu)?/p>
。
網(wǎng)絡(luò)的總損失 L 是重建和潛在損失之和的期望值:
對(duì)于每個(gè)隨機(jī)梯度下降,我們使用單個(gè) z 樣本進(jìn)行優(yōu)化。
L^Z 可以解釋為從之前的序列向解碼器傳輸潛在樣本序列 z_1:T 所需的 NAT 數(shù)量,并且(如果 x 是離散的)L^x 是解碼器重建給定 z_1:T 的 x 所需的 NAT 數(shù)量。因此,總損失等于解碼器和之前數(shù)據(jù)的預(yù)期壓縮量。
改善圖片
正如 EricJang 在他的文章中提到的,讓我們的神經(jīng)網(wǎng)絡(luò)僅僅“改善圖像”而不是“一次完成圖像”會(huì)更容易些。正如人類(lèi)藝術(shù)家在畫(huà)布上涂涂畫(huà)畫(huà),并從繪畫(huà)過(guò)程中推斷出要修改什么,以及下一步要繪制什么。
改進(jìn)圖像或逐步細(xì)化只是一次又一次地破壞我們的聯(lián)合 distribution P(C) ,導(dǎo)致潛在變量鏈 C1,C2,…CT 1 呈現(xiàn)新的變量分布 P(CT) 。
訣竅是多次從迭代細(xì)化分布 P(Ct|Ct 1)中取樣,而不是直接從 P(C) 中取樣。
在 DRAW 模型中, P(Ct|Ct 1) 是所有 t 的同一 distribution,因此我們可以將其表示為以下遞歸關(guān)系(如果不是,那么就是Markov Chain而不是遞歸網(wǎng)絡(luò)了)。
DRAW模型的實(shí)際應(yīng)用
假設(shè)你正在嘗試對(duì)數(shù)字 8 的圖像進(jìn)行編碼。每個(gè)手寫(xiě)數(shù)字的繪制方式都不同,有的樣本 8 可能看起來(lái)寬一些,有的可能長(zhǎng)一些。如果不注意,編碼器將被迫同時(shí)捕獲所有這些小的差異。
但是……如果編碼器可以在每一幀上選擇一小段圖像并一次檢查數(shù)字 8 的每一部分呢?這會(huì)使工作更容易,對(duì)吧?
同樣的邏輯也適用于生成數(shù)字。注意力單元將決定在哪里繪制數(shù)字 8 的下一部分-或任何其他部分-而傳遞的潛在矢量將決定解碼器生成多大的區(qū)域。
基本上,如果我們把變分的自動(dòng)編碼器(VAE)中的潛在代碼看作是表示整個(gè)圖像的矢量,那么繪圖中的潛在代碼就可以看作是表示筆畫(huà)的矢量。最后,這些向量的序列實(shí)現(xiàn)了原始圖像的再現(xiàn)。
好吧,那么它是如何工作的呢?
在一個(gè)遞歸的 VAE 模型中,編碼器在每一個(gè) timestep 會(huì)接收整個(gè)輸入圖像。在 Draw 中,我們需要將焦點(diǎn)集中在它們之間的 attention gate 上,因此編碼器只接收到網(wǎng)絡(luò)認(rèn)為在該 timestep 重要的圖像部分。第一個(gè) attention gate 被稱(chēng)為“Read”attention。
“Read”attention分為兩部分:
選擇圖像的重要部分和裁剪圖像
選擇圖像的重要部分
為了確定圖像的哪一部分最重要,我們需要做些觀察,并根據(jù)這些觀察做出決定。在 DRAW中,我們使用前一個(gè) timestep 的解碼器隱藏狀態(tài)。通過(guò)使用一個(gè)簡(jiǎn)單的完全連接的圖層,我們可以將隱藏狀態(tài)映射到三個(gè)決定方形裁剪的參數(shù):中心 X、中心 Y 和比例。
裁剪圖像
現(xiàn)在,我們不再對(duì)整個(gè)圖像進(jìn)行編碼,而是對(duì)其進(jìn)行裁剪,只對(duì)圖像的一小部分進(jìn)行編碼。然后,這個(gè)編碼通過(guò)系統(tǒng)解碼成一個(gè)小補(bǔ)丁。
現(xiàn)在我們到達(dá) attention gate 的第二部分, “write”attention,(與“read”部分的設(shè)置相同),只是“write”attention 使用當(dāng)前的解碼器,而不是前一個(gè) timestep 的解碼器。
雖然可以直觀地將注意力機(jī)制描述為一種裁剪,但實(shí)踐中使用了一種不同的方法。在上面描述的模型結(jié)構(gòu)仍然精確的前提下,使用了gaussian filters矩陣,沒(méi)有利用裁剪的方式。我們?cè)贒RAW 中取了一組每個(gè) filter 的中心間距都均勻的gaussian filters 矩陣 。
代碼一覽
我們?cè)?Eric Jang 的代碼的基礎(chǔ)上,對(duì)其進(jìn)行一些清理和注釋,以便于理解.
Eric 為我們提供了一些偉大的功能,可以幫助我們構(gòu)建 “read” 和 “write” 注意門(mén)徑,還有過(guò)濾我們將使用的初始狀態(tài)功能,但是首先,我們需要添加新的功能,來(lái)使我們能創(chuàng)建一個(gè)密集層并合并圖像。并將它們保存到本地計(jì)算機(jī)中,以獲取更新的代碼。
現(xiàn)在讓我們把代碼放在一起以便完成。
原文來(lái)源:https://new.qq.com/omn/20190903/20190903A09ORN00.html
總結(jié)
以上是生活随笔為你收集整理的158行代码!程序员复现DeepMind图像生成神器的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: .htaccess:正则表达式、重定向代
- 下一篇: 禁用人脸识别四个月后,旧金山人证明了他们