ICCV 2021 | 通过显式寻找物体的extremity区域加快DETR的收敛
本文將解讀筆者發(fā)表在 ICCV 2021的工作。我們針對 DEtection Transformer (DETR) 訓(xùn)練收斂慢的問題(需要訓(xùn)練 500 epoch 才能獲得比較好的效果)提出了 conditional cross-attention mechanism,通過 conditional spatial query 顯式地尋找物體的 extremity 區(qū)域,從而縮小搜索物體的范圍,加速了收斂。結(jié)構(gòu)上只需要對 DETR 的 cross-attention 部分做微小的改動,就能將收斂速度提高 6-10 倍。
?作者?|?Charles
單位?|?微軟亞洲研究院實習(xí)生
研究方向?|?計算機視覺
論文標(biāo)題:
Conditional DETR for Fast Training Convergence
論文鏈接:
https://arxiv.org/pdf/2108.06152.pdf
代碼鏈接:
https://github.com/Atten4Vis/ConditionalDETR
背景和動機
1.1 DETR 簡介
最近提出的 DETR 成功地將 transformer 引入到物體檢測任務(wù)中,獲得了很不錯的性能。DETR 的重要意義在于去除了物體檢測算法里需要人工設(shè)計的部分,比如 anchor 的生成和 NMS 操作。這大大簡化了物體檢測的設(shè)計流程。
DETR 由 CNN backbone,transformer encoder,transformer decoder 和 prediction heads 組成:
1. CNN backbone 提取圖像的 feature;
2. Encoder 通過 self-attention 建模全局關(guān)系對 feature 進行增強;
3. Decoder 主要包含 self-attention 和 cross-attention。Cross- attention 中有若干 queries,每個 query 去由 encoder feature 構(gòu)造的 key 中進行查詢,找到跟物體有關(guān)的區(qū)域,將這些區(qū)域的 feature 提取出來。Self-attention 則在不同的 query 之間進行交互,實現(xiàn)類似 NMS 的效果;
4. 最后的 prediction heads 基于每個 query 在 decoder 中提取到的特征,預(yù)測出物體的 bounding box 的位置和類別。然而,DETR 的訓(xùn)練收斂速度非常慢,要訓(xùn)練 500 epochs 才能達到比較好的性能。
下圖是對 DETR 的 decoder cross-attention 中 attention map 的可視化。我們可以看到,DETR decoder cross-attention 里的 query 查詢到的區(qū)域都是物體的 extremity 區(qū)域,比如左圖中大象的鼻子、后背、腳掌。通過這些關(guān)鍵區(qū)域,我們能夠準(zhǔn)確地定位物體的位置,識別出物體的類別。
1.2 DETR 收斂慢的原因?
為了分析 DETR 為什么收斂慢,我們對 DETR decoder cross-attention 中的 spatial attention map 進行了可視化。下圖中第一行是我們的 Conditional DETR 的結(jié)果,第二行是 DETR 訓(xùn)練 50 epoch 的結(jié)果,第三行是 DETR 訓(xùn)練 500 epoch 的結(jié)果。由于 DETR 使用了 multi-head attention,這里的每一列對應(yīng)了一個 head。
?
我們可以看到,每個 head 的 spatial attention map 都在嘗試找物體的一個 extremity 區(qū)域,例如:圍繞物體的 bounding box 的某條邊。訓(xùn)練了 500 epoch 的 DETR 基本能夠找準(zhǔn) extremity 區(qū)域的大概位置,然而只訓(xùn)練了 50 epoch 的 DETR 卻找不準(zhǔn)。
我們認(rèn)為,DETR 在計算 cross-attention 時,query 中的 content embedding 要同時和 key 中的 content embedding 以及 key 中的 spatial embedding 做匹配,這就對 content embedding 的質(zhì)量要求非常高。
而訓(xùn)練了 50 epoch 的DETR,因為 content embedding 質(zhì)量不高,無法準(zhǔn)確地縮小搜尋物體的范圍,導(dǎo)致收斂緩慢。所以用一句話總結(jié) DETR 收斂慢的原因,就是 DETR 高度依賴高質(zhì)量的 content embedding 去定位物體的 extremity 區(qū)域,而這部分區(qū)域恰恰是定位和識別物體的關(guān)鍵。?
為了解決對高質(zhì)量 content embedding 的依賴,我們將 DETR decoder cross-attention 的功能進行解耦,并提出 conditional spatial embedding。Content embedding 只負(fù)責(zé)根據(jù)外觀去搜尋跟物體相關(guān)的區(qū)域,而不用考慮跟 spatial embedding 的匹配; 對于 spatial 部分,conditional spatial embedding 可以顯式地定位物體的 extremity 區(qū)域,縮小搜索物體的范圍。
Conditional DETR
2.1 Overview
我們的方法沿用了 DETR 的整體流程,包括 CNN backbone,transformer encoder,transformer decoder, 以及 object class 和 box 位置的預(yù)測器。Encoder 和 decoder 各自由6個相同的 layer 堆疊而成。我們相對于 DETR 的改動主要在 cross-attention 部分。?
2.1.1 Box Regression?
我們從每個 decoder embedding (一個 object query 會對應(yīng)一個 decoder embedding)預(yù)測一個候選框:
?
這里, 是decoder embedding, 是 4 維向量:,前兩維是 box 的中心,后兩維是長和寬。sigmoid 函數(shù)用來將預(yù)測的向量處理到 [0, 1] 區(qū)間,表示相對于這個圖像的位置/相對于圖像長寬的大小。FFN 用來預(yù)測 unnormalized box, 是從 reference point 產(chǎn)生的 unnormalized 2D 坐標(biāo)。Reference point 是從 object query 預(yù)測出的一個坐標(biāo),大概估計了這個 query 負(fù)責(zé)的區(qū)域范圍。在原始 DETR 中沒有 reference point 的概念,因此它的 是 (0,0)。這里 也可以直接作為一個模型參數(shù)來學(xué)習(xí),而非從 reference point 預(yù)測,我們的實驗發(fā)現(xiàn)效果僅僅略微差一些。
2.1.2?Category prediction
我們使用 FFN 預(yù)測每個候選框的類別:
2.2 DETR Cross-Attention
DETR 的 cross-attention 有三個輸入:query, key, value。Query 由來自 decoder 中 self-attention 的輸出 (content query: ) 和所有圖片共享的 object query (spatial query: , 在 DETR 中其實就是 object query ) 相加得到。Key 由來自 encoder 的輸出 (content key: ) 和對于 2D 坐標(biāo)的位置編碼 (spatial key: ) 相加得到。Value 的組成和 key 相同。
在這里,content 代表這個向量的內(nèi)容和圖像 (顏色、紋理等) 是相關(guān)的,而 spatial 代表這個向量它更多包含空間上的信息,他的內(nèi)容和圖像的內(nèi)容無關(guān)。Attention 模塊的輸出,就是對 query 和 key 算一次內(nèi)積得到注意力的權(quán)重,用這個權(quán)重給 value 進行加權(quán)。我們將這個過程寫成下面的形式:
2.3 Conditional Cross-Attention
我們對 DETR 的 cross-attention 中 query 和 key 的形式做了些改變。Query 由 content query 和 spatial query concat 而成,key 由 content key 和 spatial key concat 而成。這樣 query 和 key 做內(nèi)積,得到如下結(jié)果:
這里只有兩項,第一項計算 content 相似度,第二項計算 spatial 相似度。
2.3.1 Conditional spatial query prediction
上文提到,我們是基于 (1) 當(dāng)前 layer 的 decoder embedding 中包含的信息,以及 (2) reference point 一起預(yù)測 box 信息的。這也就是說, decoder embedding 中包含了 box 有關(guān)的區(qū)域 (比如box的四條邊、或者box內(nèi)部的點)到 reference point 的偏移量。因此,我們在生成 conditional spatial query 的時候,也要同時考慮 reference point s 和 decoder embedding f:
和 box prediction 類似,我們的 也由兩部分組成,一個 reference,一個“偏移量”。因為這里的 reference 在一個高維位置編碼空間中,所以“偏移量”也不再是 xy 方向的值,而是一個施加在高維向量上的 projection 函數(shù)。
首先,我們將該 query 對應(yīng)的 reference point 的 2D 坐標(biāo)歸一化之后映射到和 spatial key 相同的正弦位置編碼空間中,得到 reference :
然后,我們將 decoder embedding 中包含的偏移量信息通過一個 FFN (linear + ReLU + linear) 映射到高維空間中,得到針對 的“偏移量”:
那么,最終的 conditional spatial query 就可以由 reference 和偏移量組合得到:。對于 我們選擇一種計算上較為簡單的設(shè)計:對角矩陣。假設(shè) 所處的空間是 256-d 的,那么對角矩陣的對角線上的 256 個元素可以記為向量 。那么 conditional spatial query 可以通過 element-wise multiplication 得到:
2.3.2?Multi-head cross-attention
和 DETR 一樣,我們在 cross-attention 中使用 multi-head 的設(shè)計。對于同一個 query,我們使用 8 個 head,即將 query/key 通過 linear projection 映射到 8 個維度更低的 sub-query/sub-key。通過這 8 個 head 各自計算出的 conditional spatial sub-query,我們可以得到關(guān)于一個物體的位置的不同角度的表達:bounding box 的四條邊,或者 bounding box 的內(nèi)部。這個我們在下面的可視化部分展示一下。
2.4 Visualization and Analysis
在這個圖中,我們可視化了同一個 query 不同 head 的 attention map。左右兩側(cè)是兩個樣例,一個是同類別只有一個個體的情況,另一個是同類別多個體的情況。圖中的高亮部分是 attention map 權(quán)重較高的區(qū)域。
1. 第一行是 spatial attention map: ,第二行是 content attention map: ,第三行是組合之后的 attention map: 。
2. 每一列表示一個 head。我們只畫了 8 個 head 中的 5 個,其余 3 個 head 對應(yīng)的區(qū)域和這 5 個有重疊,所以沒有進行展示。?
從圖中,我們可以得出結(jié)論:
1. 每個 head 的 spatial attention map 對應(yīng)了跟 box 有關(guān)的一個區(qū)域。有趣的是,有些 head 對應(yīng)的區(qū)域甚至跟 bounding box 的幾條邊重合了,分別對應(yīng)了上、下、左、右四條邊。另外一個對應(yīng)了物體內(nèi)部的一小塊區(qū)域,這個區(qū)域的特征經(jīng)過 transformer encoder 的處理,或許已經(jīng)足夠主要作用是用來對物體進行識別和分類。
2. 每個 head 的 content attention map 對應(yīng)了跟物體外觀相似的一些區(qū)域 (甚至是同類別的其他個體)。我們從右邊的例子可以看出來,想檢測小牛,但是 content attention 很多都聚焦到大牛的身上,這顯然是不利于檢測的。
3. 當(dāng)我們將 content 和 spatial attention map 進行組合,我們發(fā)現(xiàn)當(dāng)前物體以外的區(qū)域被完美地過濾掉了,剩下的高亮區(qū)域基本存在于物體的一些 extremity 區(qū)域,比如右側(cè)樣例中小牛的頭上、腳上這些跟 bounding box 有重合的區(qū)域。
2.4.1 對可視化的一些分析
根據(jù)上面的可視化結(jié)果,我們對 conditional spatial query 的作用做了分析。它的作用有兩方面:
1)將spatial attention map 的高亮區(qū)域映射到物體的四條邊界上和中心區(qū)域。有趣的是,對于不同的物體,同一個 head 的這些高亮區(qū)域相對于 bounding box 的位置是類似的;
2)可以根據(jù)物體的大小將 spatial attention map 高亮的區(qū)域做縮放:對于大物體,有更大的 spread 范圍,對于小物體則有更小的 spread 范圍。這些作用都?xì)w功于之前提到的作用于 reference 的變換 。
實驗
3.1 數(shù)據(jù)集介紹
我們在 COCO 2017 Detection dataset 上進行實驗,該數(shù)據(jù)集包括 118K 圖像的訓(xùn)練集和 5K 圖像的驗證集。
3.2 和 DETR 的性能對比
從表中我們可以看到:
1. DETR 50 epoch 的模型比 500 epoch 的模型差很多。
2. 當(dāng)使用 ResNet-50/ResNet-101 作為 backbone 時,Conditional DETR 訓(xùn)練 50 epoch 的模型比 DETR 訓(xùn)練 500 epoch 的模型稍差一些;當(dāng)使用 DC5-ResNet-50/DC5-ResNet-101 作為 backbone 時,Conditional DETR 訓(xùn)練 50 epoch 可以達到與 DETR 訓(xùn)練 500 epoch 差不多/更高的結(jié)果。當(dāng) Conditional DETR 訓(xùn)練 75 epoch 及以上,4 種不同的 backbone 都可以超過 DETR 訓(xùn)練 500 epoch 的結(jié)果。這也說明了在更強的backbone下,Conditional DETR 相對于 DETR 能表現(xiàn)得更好。?
3. DC5-ResNet backbone 下,Conditional DETR 可以比 DETR 的收斂速度快 10倍;ResNet backbone 下,Conditional DETR 可以比 DETR 的收斂速度快 6.67 倍。?
除此之外,我們在 single-scale 的條件下,還跟 Deformable DETR 以及 UP-DETR 進行對比。在 ResNet-50/DC5-ResNet-50 backbone下,我們的方法都超過了 Deformable DETR-SS。盡管他們的計算量、參數(shù)量不同,仍然說明了 Conditional DETR 是很有效的。當(dāng)與 UP-DETR 比較,我們的方法用更少的 epoch 獲得了更高的結(jié)果。
3.3 和多尺度/高分辨率下的 DETR 的變種算法的對比
Conditional DETR 的目的是加速 DETR 的訓(xùn)練過程,所以并沒有處理 encoder 帶來的大量計算量的問題。因此,我們并沒有期望 Conditional DETR 能夠達到其他使用 8x 的分辨率/多尺度的方法相近的結(jié)果。?
然而,我們發(fā)現(xiàn)在 DC5-R50 的 backbone 下,我們的方法居然和 Deformable DETR 表現(xiàn)的一樣好,均達到了 43.8 的 AP。值得一提的是,只使用 single scale 的模型 Deformable DETR-DC5-R50-SS 僅有 41.5 的 AP,說明他們的算法很大程度上受益于 multi-scale 的設(shè)計。?
我們的方法也取得了跟 TSP-FCOS/TSP-RCNN 持平的結(jié)果。他們的方法是對 FCOS/Faster FCNN 的擴展。沒有使用 transformer decoder,而是將 transformer encoder 放在少量的選定的位置之后 (在 Faster RCNN 中他們用的 region proposal),這使得他們在 self-attention 部分的計算量大幅減小。
總結(jié)
在這篇論文中,為了加速 DETR 的收斂速度,我們提出一個簡單而有效的 conditional cross-attention 機制。該機制的核心是從 decoder embedding 和 reference point 中學(xué)習(xí)到一個 conditional spatial query。這個 query 可以顯式地去找物體的 extremity 區(qū)域,從而縮小了搜索物體的范圍,幫助物體的定位,緩解了 DETR 訓(xùn)練中對于 content embedding 過度依賴的問題。
?
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識的人。
總有一些你不認(rèn)識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點剖析、科研心得或競賽經(jīng)驗講解等。我們的目的只有一個,讓知識真正流動起來。
?????稿件基本要求:
? 文章確系個人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺已發(fā)表或待發(fā)表的文章,請明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競爭力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
?????投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯(lián)系方式(微信),以便我們在稿件選用的第一時間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
????
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的ICCV 2021 | 通过显式寻找物体的extremity区域加快DETR的收敛的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: fomc和美联储什么关系 联合会议可以决
- 下一篇: 荷兰用什么货币