Swin Transformer对CNN的降维打击
一、前言
一張圖告訴你Transformer現在是多么的強!幾乎包攬了ADE20K語義分割的前幾名!
該文章詳細解讀Swin-transformer的相關內容以及高明之處。看完學不會,你在評論區打我!CNN已然在計算機視覺領域取得了革命性的成果,擁有著不可撼動的地位。Transformer最初用于NLP領域,但Transformer憑借其強大的特征表征能力,已經在cv領域殺出了一條血路。
paper鏈接:https://arxiv.org/pdf/2103.14030.pdf
代碼鏈接:https://github.com/microsoft/Swin-Transformer
二、Swin Transformer
2.1 背景
Transformer最開始用于NLP領域,但其強大的表征能力讓cv領域的研究人員垂涎欲滴。然而從NLP轉為cv領域,存在兩個天然的問題。
- 1.相較于文本,圖像中像素的分辨率更高
- 2.圖像的視覺實體尺寸之間差異很大
傳統Transformer(例如transformer、ViT等)盡管有強大的特征表達能力,但其巨大計算量的問題讓人望而卻步。與傳統Transformer不同的是,Swin-Transformer解決了Transformer一直飽受詬病的計算量問題。那么,Swin-Transformer是如何解決的計算量問題呢?讓我們繼續往下看吧。
2.2 Architecture概況
學習swin transformer之前,我們首先需要熟知以下幾個概念:
- Resolution:假設一張圖像的分辨率為224x224,這里所說的224就是像素。
- Patch:所謂的Patch就是由多少個像素點構成的,假設一個patch的size為4x4,則這個patch包含16個像素點。
- Window:window的size是由patch決定的,而不是由像素點,假設window的size為7x7,則該window包含49個patch,而不是49個像素點。
在對swin-transformer網絡進行講解之前,我們首先需要明確一點:無論是transformer還是swin-transformer結構,都不會改變輸入的形狀,換句話說,輸入是什么樣,經過transformer或swin-transformer后,輸出跟輸入的形狀是相同的。
一般而言,我拿到一篇論文之后,會首先分析每個塊的輸入輸出是怎樣的,先從整體上對網絡結構把握,然后在慢慢的細化。我們首先來梳理一下swin-transformer每個塊的輸入輸出。
| input image | 224x224x3 | |
| patch partition | 224/4 x 224/4 x 4x4x3 | |
| 1 | linear embedding | 224/4 x 224/4 x 96 |
| 1 | swin transformer | 224/4 x 224/4 x 96 |
| 2 | patch merging | 224/8 x 224/8 x 192 |
| 2 | swin transformer | 224/8 x 224/8 x 192 |
| 3 | patch merging | 224/16 x 224/16 x 192 |
| 3 | swin transformer | 224/16 x 224/16 x 192 |
| 4 | patch merging | 224/32 x 224/32 x 384 |
| 4 | swin transformer | 224/32 x 224/32 x 384 |
從結構圖中可以看到,swin-transformer網絡結構主要包括以下層:
- 1.Patch Partition:將輸入圖像劃分為若干個patch
- 2.Linear Embedding:將輸入圖像映射要任意維度(論文中記為C,即C=96)
- 3.Patch Merging:降低分辨率,擴大感受野,獲得多層次的特征信息,類似于CNN中的pool層
- 4.swin transformer:特征提取及特征表征
2.3 swin-transformer結構解析
到這里我們已經大致了解swin-transformer網絡的基本結構,接下來,跟著我一塊揭開Swin-transformer的真面目吧。一個swin-transformer block由兩個連續的swin-transformer結構組成,兩個結構不同之處在于:第一個結構中使用的是在一個window中計算self-attention,記為W-MSA;第二個結構中使用的是shifted window技術,記為SW-MSA。 在這一章節中,我們重點介紹swin-transformer是如何在一個window中進行self-attention計算的。
假設我們將window size設置為4,則一個window中包含4x4個Patch,如下圖中的Layer l的不重疊窗口劃分結果。但只在window中進行self-attention計算,使得各個windows之間缺乏信息的交互,這限制了swin-transformer的特征表達能力。
為此,swin-transformer的作者提出了top-left的窗口移位方式,如下圖中Layer l+1所示。但這樣的window移位方式增加了window的數量(從2x2->3x3),增加了2.25倍,且window之間的size也不盡相同,這導致無法進行并行計算。
基于上述兩個原因,作者提出了shifted window技術,這也是整篇文章的核心所在。那么shifted window的過程是怎樣的呢?
2.4 shifted window
假設input image的size為224x224,window的size為7x7,patch size為4X4,那么input image包含224/4 x 224/4個patch(56x56),如下圖中的第一張圖。我們將其劃分為不重疊的window,每個window包含7x7個patch,如下圖中的第2張圖。接下來,我們將整張圖像沿主對角線方向移位(floor(M/2),floor(M/2))個patch,這里的M代表window的size,則本例中移位(3,3)個patch,如第3張圖所示。移位后,可以看到,一個window包含4個不同window的部分,如第4張圖所示(藍色網格線)。
我們假設移位后的圖像是如下圖所示的。我們分別對不同的區域進行編碼,為什么要進行編碼呢?這是因為我們對一個window中不同區域Patch進行self-attention計算沒有任何意義。例如,區域3和區域4在原圖中就是兩個不相鄰的區域,本身之間沒有任何的聯系。那么,swin-transformer是如何實現一個window中只有相同區域才進行self-attention計算的呢?
我們以右下角4個均不同的區域為案例進行演示。為簡潔,我們將右下角的一個window進行簡化,由原來的49個patch簡化為4個patch,但過程是相同的。
首先我們根據patch的數量建立一個相關矩陣,本例中patch的數量為4,則建立一個4x4的矩陣,然后將x和y進行相減,相減后,相同區域的結果為0,不同區域的結果我們將其置為負無窮,得到一個mask矩陣,然后與計算得到的attention矩陣進行相加,這樣便實現了相同區域進行self-attention計算。
2.5 Relative position bias
公式中的B即為相對位置信息。那么相對位置信息是如何計算的呢?我們假設有p1、p2、p3、p4四個patch,分別以p1、p2、p3、p4為原點,計算其余patch相對于原點的偏移量,如表1所示。計算完畢后,我們會發現有以下2個問題:
- 1.相對位置信息中出現負數
- 2.(0,1)和(1,0)雖然是2個不同的相對位置信息,但是它們相加的總偏移量相等。
為了解決以上2個問題,論文作者做了如下操作: - 1.為了方便后續計算,每個坐標的位置都加上偏移量,使其從0開始,避免負數的出現。
- 2.對0維度進行乘法變換,論文中是對0維度的數值乘以(2M-1)。
- 3.將0維度和1維度的數值進行相加,得到一個index值。
- 4.根據index的值,映射到權重矩陣中得到相應的權重值。
- 5.將attention矩陣與權重矩陣進行相加。
2.6 循環窗口移動技術是如何實現的
其實原理很簡單,就是使用了torch.roll()這個方法,關于方法的解釋及代碼如下,大家可以了解一下。
總結
以上是生活随笔為你收集整理的Swin Transformer对CNN的降维打击的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习保姆级入门教程 -- 论文+代码
- 下一篇: NAT和PAT