Self-Attention GAN 中的 self-attention 机制
作者丨尹相楠
學校丨里昂中央理工博士在讀
研究方向丨人臉識別、對抗生成網絡
Self Attention GAN 用到了很多新的技術。最大的亮點當然是 self-attention 機制,該機制是 Non-local Neural Networks?[1] 這篇文章提出的。其作用是能夠更好地學習到全局特征之間的依賴關系。因為傳統的 GAN 模型很容易學習到紋理特征:如皮毛,天空,草地等,不容易學習到特定的結構和幾何特征,例如狗有四條腿,既不能多也不能少。?
除此之外,文章還用到了 Spectral Normalization for GANs [2]?提出的譜歸一化。譜歸一化的解釋見本人這篇文章:詳解GAN的譜歸一化(Spectral Normalization)。
但是,該文代碼中的譜歸一化和原始的譜歸一化運用方式略有差別:?
1. 原始的譜歸一化基于 W-GAN 的理論,只用在 Discriminator 中,用以約束 Discriminator 函數為 1-Lipschitz 連續。而在 Self-Attention GAN 中,Spectral Normalization 同時出現在了 Discriminator 和 Generator 中,用于使梯度更穩定。除了生成器和判別器的最后一層外,每個卷積/反卷積單元都會上一個 SpectralNorm。?
2. 當把譜歸一化用在 Generator 上時,同時還保留了 BatchNorm。Discriminator 上則沒有 BatchNorm,只有 SpectralNorm。?
3. 譜歸一化用在 Discriminator 上時最后一層不加 Spectral Norm。?
最后,self-attention GAN 還用到了 cGANs With Projection Discriminator 提出的 conditional normalization 和 projection in the discriminator。這兩個技術我還沒有來得及看,而且 PyTorch 版本的 self-attention GAN 代碼中也沒有實現,就先不管它們了。
本文主要說的是 self-attention 這部分內容。
▲?圖1.?Self-Attention
Self-Attention
在卷積神經網絡中,每個卷積核的尺寸都是很有限的(基本上不會大于 5),因此每次卷積操作只能覆蓋像素點周圍很小一塊鄰域。
對于距離較遠的特征,例如狗有四條腿這類特征,就不容易捕獲到了(也不是完全捕獲不到,因為多層的卷積、池化操作會把 feature map 的高和寬變得越來越小,越靠后的層,其卷積核覆蓋的區域映射回原圖對應的面積越大。但總而言之,畢竟還得需要經過多層映射,不夠直接)。
Self-Attention 通過直接計算圖像中任意兩個像素點之間的關系,一步到位地獲取圖像的全局幾何特征。?
論文中的公式不夠直觀,我們直接看文章的 PyTorch 的代碼,核心部分為 sagan_models.py:
????"""?Self?attention?Layer"""
????def?__init__(self,in_dim,activation):
????????super(Self_Attn,self).__init__()
????????self.chanel_in?=?in_dim
????????self.activation?=?activation
????????self.query_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim//8?,?kernel_size=?1)
????????self.key_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim//8?,?kernel_size=?1)
????????self.value_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim?,?kernel_size=?1)
????????self.gamma?=?nn.Parameter(torch.zeros(1))
????????self.softmax??=?nn.Softmax(dim=-1)?#
????def?forward(self,x):
????????"""
????????????inputs?:
????????????????x?:?input?feature?maps(?B?X?C?X?W?X?H)
????????????returns?:
????????????????out?:?self?attention?value?+?input?feature?
????????????????attention:?B?X?N?X?N?(N?is?Width*Height)
????????"""
????????m_batchsize,C,width?,height?=?x.size()
????????proj_query??=?self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)?#?B?X?CX(N)
????????proj_key?=??self.key_conv(x).view(m_batchsize,-1,width*height)?#?B?X?C?x?(*W*H)
????????energy?=??torch.bmm(proj_query,proj_key)?#?transpose?check
????????attention?=?self.softmax(energy)?#?BX?(N)?X?(N)?
????????proj_value?=?self.value_conv(x).view(m_batchsize,-1,width*height)?#?B?X?C?X?N
????????out?=?torch.bmm(proj_value,attention.permute(0,2,1)?)
????????out?=?out.view(m_batchsize,C,width,height)
????????out?=?self.gamma*out?+?x
????????return?out,attention
構造函數中定義了三個 1?× 1 的卷積核,分別被命名為 query_conv , key_conv 和 value_conv 。
為啥命名為這三個名字呢?這和作者給它們賦予的含義有關。query 意為查詢,我們希望輸入一個像素點,查詢(計算)到 feature map 上所有像素點對這一點的影響。而 key 代表字典中的鍵,相當于所查詢的數據庫。query 和 key 都是輸入的 feature map,可以看成把 feature map 復制了兩份,一份作為 query 一份作為 key。?
需要用一個什么樣的函數,才能針對 query 的 feature map 中的某一個位置,計算出 key 的 feature map 中所有位置對它的影響呢?作者認為這個函數應該是可以通過“學習”得到的。那么,自然而然就想到要對這兩個 feature map 分別做卷積核為 1?× 1 的卷積了,因為卷積核的權重是可以學習得到的。?
至于 value_conv ,可以看成對原 feature map 多加了一層卷積映射,這樣可以學習到的參數就更多了,否則 query_conv 和 key_conv 的參數太少,按代碼中只有 in_dims × in_dims//8 個。?
接下來逐行研究 forward 函數:
這行代碼先對輸入的 feature map 卷積了一次,相當于對 query feature map 做了一次投影,所以叫做 proj_query。由于是 1?× 1 的卷積,所以不改變 feature map 的長和寬。feature map 的每個通道為如 (1) 所示的矩陣,矩陣共有 N 個元素(像素)。
然后重新改變了輸出的維度,變成:
?(m_batchsize,-1,width*height)?
batch size 保持不變,width 和 height 融合到一起,把如 (1) 所示二維的 feature map 每個 channel 拉成一個長度為 N 的向量。
因此,如果 m_batchsize 取 1,即單獨觀察一個樣本,該操作的結果是得到一個矩陣,矩陣的的行數為 query_conv 卷積輸出的 channel 的數目 C( in_dim//8 ),列數為 feature map 像素數 N。
然后作者又通過 .permute(0, 2, 1) 轉置了矩陣,矩陣的行數變成了 feature map 的像素數 N,列數變成了通道數 C。因此矩陣維度為 N?× C 。該矩陣每行代表一個像素位置上所有通道的值,每列代表某個通道中所有的像素值。
▲?圖2.?proj_query 的維度
這行代碼和上一行類似,只不過取消了轉置操作。得到的矩陣行數為通道數 C,列數為像素數 N,即矩陣維度為 C?× N。該矩陣每行代表一個通道中所有的像素值,每列代表一個像素位置上所有通道的值。
▲?圖3. proj_key的維度
這行代碼中, torch.bmm 的意思是 batch matrix multiplication。就是說把相同 batch size 的兩組 matrix 一一對應地做矩陣乘法,最后得到同樣 batchsize 的新矩陣。
若 batch size=1,就是普通的矩陣乘法。已知 proj_query 維度是 N?× C, proj_key 的維度是 C?×?N,因此 energy 的維度是 N?× N:
▲?圖4. energy的維度
energy 是 attention 的核心,其中第 i 行 j 列的元素,是由 proj_query 第 i 行,和 proj_key 第 j 列通過向量點乘得到的。而 proj_query 第 i 行表示的是 feature map 上第 i 個像素位置上所有通道的值,也就是第 i 個像素位置的所有信息,而 proj_key 第 j 列表示的是 feature map 上第 j 個像素位置上的所有通道值,也就是第 j 個像素位置的所有信息。
這倆相乘,可以看成是第 j 個像素對第 i 個像素的影響。即,energy 中第 i 行 j 列的元素值,表示第 j 個像素點對第 i 個像素點的影響。
這里 sofmax 是構造函數中定義的,為按“行”歸一化。這個操作之后的矩陣,各行元素之和為 1。這也比較好理解,因為 energy 中第 i 行元素,代表 feature map 中所有位置的像素對第 i 個像素的影響,而這個影響被解釋為權重,故加起來應該是 1,故應對其按行歸一化。attention 的維度也是 N?× N。
上面的代碼中,先對原 feature map 作一次卷積映射,然后把得到的新 feature map 改變形狀,維度變為 C?×?N ,其中 C 為通道數(注意和上面計算 proj_query???proj_key 的 C 不同,上面的 C 為 feature map 通道數的 1/8,這里的 C 與 feature map 通道數相同),N 為 feature map 的像素數。
▲?圖5.?proj_value的維度
out?=?out.view(m_batchsize,C,width,height)
然后,再把 proj_value (C?× N)矩陣同? attention 矩陣的轉置(N?× N)相乘,得到 out (C?× N)。之所以轉置,是因為 attention 中每行的和為 1,其意義是權重,需要轉置后變為每列的和為 1,施加于 proj_value 的行上,作為該行的加權平均。 proj_value 第 i 行代表第 i 個通道所有的像素值, attention 第 j 列,代表所有像素施加到第 j 個像素的影響。
因此, out 中第 i 行包含了輸出的第 i 個通道中的所有像素,第 j 列表示所有像素中的第 j 個像素,合起來也就是: out 中的第 i 行第 j 列的元素,表示被 attention 加權之后的 feature map 的第 i 個通道的第 j 個像素的像素值。再改變一下形狀, out 就恢復了 channel×width×height 的結構。
▲?圖6.?out的維度
最后一行代碼,借鑒了殘差神經網絡(residual neural networks)的操作, gamma 是一個參數,表示整體施加了 attention 之后的 feature map 的權重,需要通過反向傳播更新。而 x 就是輸入的 feature map。
在初始階段, gamma 為 0,該 attention 模塊直接返回輸入的 feature map,之后隨著學習,該 attention 模塊逐漸學習到了將 attention 加權過的 feature map 加在原始的 feature map 上,從而強調了需要施加注意力的部分 feature map。
總結
可以把 self attention 看成是 feature map 和它自身的轉置相乘,讓任意兩個位置的像素直接發生關系,這樣就可以學習到任意兩個像素之間的依賴關系,從而得到全局特征了。看論文時會被它復雜的符號迷惑,但是一看代碼就發現其實是很 naive 的操作。
參考文獻
[1] Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He, Non-local Neural Networks, CVPR 2018.
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, Spectral Normalization for Generative Adversarial Networks, ICLR 2018.
點擊以下標題查看更多往期內容:?
Airbnb實時搜索排序中的Embedding技巧
圖神經網絡綜述:模型與應用
近期值得讀的10篇GAN進展論文
自然語言處理中的語言模型預訓練方法
從傅里葉分析角度解讀深度學習的泛化能力
深度思考 | 從BERT看大規模數據的無監督利用
AI Challenger 2018 機器翻譯參賽總結
小米拍照黑科技:基于NAS的圖像超分辨率算法
異構信息網絡表示學習論文解讀
不懂Photoshop如何P圖?交給深度學習吧
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢??答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
??來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
?
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點擊 |?閱讀原文?| 獲取最新論文推薦
總結
以上是生活随笔為你收集整理的Self-Attention GAN 中的 self-attention 机制的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 巧断梯度:单个loss实现GAN模型(附
- 下一篇: 你不是一个人在战斗!有人将吴恩达的视频教