48小时单GPU训练DistilBERT!这个检索模型轻松达到SOTA
?PaperWeekly 原創 ·?作者 | Maple小七
單位 | 北京郵電大學
研究方向 | 自然語言處理
論文標題:?
Efficiently Teaching an Effective Dense Retriever with Balanced Topic Aware Sampling
收錄會議:
SIGIR 2021
論文鏈接:
https://arxiv.org/abs/2104.06967
代碼鏈接:
https://github.com/sebastian-hofstaetter/tas-balanced-dense-retrieval
基于 BERT 的稠密檢索模型雖然在 IR 領域取得了階段性的成功,但檢索模型的訓練、索引和查詢效率一直是 IR 社區關注的重點問題,雖然超越 SOTA 的檢索模型越來越多,但模型的訓練成本也越來越大,以至于要訓練最先進的稠密檢索模型通常都需要 8×V100 的配置。而采用本文提出的 TAS-Balanced 和 Dual-supervision 訓練策略,我們僅需要在單個消費級 GPU 上花費 48 小時從頭訓練一個 6 層的 DistilBERT 就能取得 SOTA 結果,這再一次證明了當前大部分稠密檢索模型的訓練是緩慢且低效的。
緒言
在短短的兩年時間內,當初被質疑是 Neural Hype 的 Neural IR 現在已經被 IR 社區廣泛接受,不少開源搜索引擎也逐漸支持了基于 BERT 的稠密檢索(dense retrieval),基本達到了開箱即用的效果。其中,DPR 提出的 是當前最主流的稠密檢索模型,然而眾所周知的是, 的可遷移性遠不如 BM25 這類 learning-free 的傳統檢索方法,想要在具體的業務場景下使用 并取得理想的結果,我們通常需要準備充足的標注數據進一步訓練檢索模型。
因此,如何高效地訓練一個又快又好的 一直是 Neural IR 的研究熱點。目前來看,改進 主要有兩條路線可走,其中一條路線是改變 batch 內的樣本組合,讓模型能夠獲取更豐富的對比信息:
優化模型的訓練過程:這類方法的代表作是 ANCE 提出的動態負采樣策略,其基本思路是在訓練過程中定期刷新索引,從而為模型提供更優質的難負樣本,而不是像 DPR 那樣僅從 BM25 中獲取負樣本。在此基礎上,LTRe 指出目前的檢索模型其實是按 learning to rank 來訓練的,因為訓練過程中模型僅能看到一個 batch 內的樣本,但如果我們只訓練 query encoder,凍結 passage embedding,我們就可以按照 learning to retrieve 的方式計算全局損失,而不是僅計算一個 batch 的損失。除此之外,RocketQA 提出了 Cross Batch 技巧來增大 batch size,由于檢索模型采用對比損失訓練,因此理論上增大 batch size 帶來的基本都是正收益。
然而,這三種策略都在原始的 的基礎上增加了額外的計算成本,并且實現都比較復雜。除此之外,我們也可以利用知識蒸餾(knowledge distillation)為模型提供更優質的監督信號:
優化模型的監督信號:?我們可以將表達能力更強但運行效率更低的 或 當作 teacher model 來為 提供 soft label。在檢索模型的訓練中,知識蒸餾的損失函數有很多可能的選擇,本文僅討論 pairwise loss 和 in-batch negative loss,其中 in-batch negative loss 在 pairwise loss 的基礎上將 batch 內部其他 query 的負樣本也當作當前 query 的負樣本,這兩類蒸餾 loss 的詳細定義后文會講。
本文同樣是在上述兩個方面對 做出優化,在訓練過程方面,作者提出了 Balanced Topic Aware Sampling(TAS-Balanced)策略來構建 batch 內的訓練樣本;在監督信號方面,作者提出了將 pairwise loss 和 in-batch negative loss 結合的 dual-supervision 蒸餾方式。
Dual Supervision
越來越多的證據表明知識蒸餾能夠帶來稠密檢索模型性能的提升,本文將 提供的 pairwise loss 和 提供的 in-batch negative loss 結合起來為 提供監督信號,下面先簡單介紹一下 teacher model 和 student model。
Teacher Model:、?
是當前應用最為廣泛的排序模型,它簡單地將 query 和 passage 的拼接作為 的輸入序列,然后對 輸出向量做一個線性變換得到相關性打分:
是一個經典的多向量表示模型,它將 query 和 passage 之間的交互簡化為 max-sum 來克服 無法緩存 passage 向量的問題,其基本思路是首先對 query 和 passage 分別編碼
然后計算每個 query term 和每個 passage term 的點積相似度,按 doc term 做 max-pooling 并按 query term 求和獲取 query 和 passage 的相似度:
雖然理論上 可以對 passage 建立離線索引,但存儲 passage 多向量表示的資源開銷是非常大的,并且該存儲成本隨著語料庫的 term 數量呈線性增長,再加上 max-sum 的操作也會帶來額外的計算成本,因此這里我們將 當作 的 teacher。
Student Model:?
DPR 提出的 僅使用二元標簽和 BM25 生成的負樣本訓練模型, 首先將 query 和 passage 獨立編碼為單個向量:
然后計算 和 的點積相似度:
在檢索階段, 首先對 query 編碼,然后利用 faiss 做最大內積檢索,下表展示了在單個消費級 GPU 上 6 層 DistilBERT 在 800 萬 passage 集合上的檢索速度。
2.1 Dual-Teacher Supervision
如果僅看監督信號的質量, 提供的 in-batch negative loss 當然是最優質的。然而, 雖然在表達能力上比 更強,但它實際上很少用于計算 in-batch negative loss,因為 需要單獨編碼每個 query-passage 樣本對,所以其計算開銷隨著 batch size 二次增長,而 解耦了 query 和 passage 的表示,因此它的開銷是隨著 batch size 線性增長的,其 in-batch negative loss 的計算效率要高得多。
因此這里我們只讓 提供 pairwise loss,具體來說,我們首先利用訓練好的 對訓練集中所有的 query-passage 樣本對打分,然后計算 的蒸餾損失,蒸餾損失的具體形式有很多選擇,這里作者選擇了 Margin-MSE loss 作為 pairwise loss:
其中 和 分別為 和 。
我們同時讓 提供 in-batch negative loss:
in-batch negative loss 中的 其實也可以替換成別的 loss,作者在后續實驗中也嘗試了一些看起來更有效的 listwise loss,然而實驗結果表明 Margin-MSE loss 依舊是最佳的選擇。因此,作者最終提出的蒸餾 loss 是 pairwise loss 和 in-batch negative loss 的加權平均,在后續實驗中,作者設加權系數 :
Balanced Topic Aware Sampling
在原始的 的訓練中,我們首先隨機地從 query 集合 中采樣 個 ,然后再為每個 隨機采樣一個正樣本 和一個負樣本 組成一個 batch:
其中 表示從集合 無放回地采樣 個樣本。由于訓練集是非常大的,每個 batch 中的 幾乎都是沒有相關性的,但是當我們計算 in-batch negative loss 時,query 不僅和自身的 交互,也和別的 query 對應的 交互,然而,由于 對模型來說大概率是簡單樣本,因此它所能提供的信息增益是非常少的,這也導致了每個 batch 所能提供的信息量偏少,使得檢索模型需要長時間的訓練才能收斂。
3.1?TAS
針對這個問題,作者提出了 Topic Aware Sampling(TAS)策略來構建 batch 內的訓練樣本,具體來說,在訓練之前,我們先利用 k-means 算法將 query 聚類到 k 個 cluster 中:
其中 query 的表示 由基線模型 提供, 為 的聚類中心,這樣,每個 cluster 中的 query 都是主題相關的,在構建 batch 的時候,我們可以先從 cluster 的集合 中隨機抽樣 個 cluster,然后在每個 cluster 上隨機抽樣 個 query:
在后續的實驗中,作者為 40 萬個 query 創建了?k=2000?個 cluster,并設 batch size 大小為 b=32,組建 batch 時隨機抽樣的 cluster 數量為 n=1,這樣,每個 batch 中的樣本都來自于同一個 cluster。如下圖所示,相比于在整個 query 集合上隨機抽樣,TAS 策略生成的 batch 內部的 query 有更高的主題相似性。
3.2?TAS-balanced
在組建 batch 的時候,我們還需要為每個采樣到的 query 配置正負樣本對 。不難想到,幾乎所有 query 對應的 都比 少得多,如果用獨立隨機抽樣的方式獲取 和 ,那么組成的 的 margin(也就是 )大概率是很大的,因此大部分 對模型來說是簡單樣本,因為模型很容易將 和 分開。
因此,我們可以在 TAS 策略的基礎上進一步均衡 batch 內正負樣本對的 margin 分布以減少 high margin(low information)的正負樣本對。具體來說,針對每個 query,我們首先計算它對應的樣本對集合的最小 margin 和最大 margin,然后將該區間分割為 個子區間,在為 query 配置 時,我們首先從這 個子區間中隨機選擇一個子區間,然后從 margin 落在該子區間內的 集合中隨機采樣并組成一個訓練樣本:
這樣,在構建一個 batch 的時候,我們首先需要采樣一個 cluster,然后采樣 b?個 query,接下來為每個 query 采樣一個 margin 子區間,最后在該子區間上采樣一個正負樣本對,這整套流程就是所謂的 TAS-balanced batch sampling:
需要注意的是,TAS-balanced 策略不會影響模型的訓練速度,因為 batch 的構建是可以并行處理或者預先處理好的。TAS-balanced 策略組建的 batch 對模型來說整體的難度更大,因此為模型提供了更多的信息量,即使采用較小的 batch size,模型也能很好地收斂。如下表所示,我們可以在消費級顯卡上(11GB 內存)高效地訓練 而不需要昂貴的 8×V100 的配置,因為該方法不需要像 ANCE 那樣重復刷新索引,也不需要像 RocketQA 那樣進行超大批量的訓練。
3.3 Experiment
作者選擇 MSMACRO-Passage 官方提供的 4000 萬正負樣本對作為檢索模型的訓練集,并選擇 MSMACRO-DEV(sparsely-judged,包含 6980 個 query)和 TREC-DL 19/20(densely-judged,包含 43/54 個 query)作為驗證集。同時 和 ?均采用 6 層的 DistilBERT 初始化,且沒有使用預訓練的檢索模型。
Results
4.1 Source of Effectiveness
首先我們對作者提出的 Dual-supervision 做消融實驗,如下表所示。對于基于 pairwise loss 的知識蒸餾,Margin-MSE loss 的優越性已經被之前的論文證明,所以這里僅討論 in-batch negative loss 的有效性。作者對比了基于 listwise loss 的 KL Divergence、ListNet 和 Lambdarank,實驗結果表明這些損失的效果都不如 Margin-MSE loss,尤其是在 R@1K 上面。
為什么 pairwise 的 Margin-MSE 比 listwise loss 更好呢?因為 Margin-MSE 不僅僅是讓模型去學習 teacher 所給出的排序,同時還學習 teacher score 的分布,由于 batch 內部樣本的 order 實際上是有偏的,它并不能準確刻畫樣本間的真實距離,因此比起學習 order,學習 score 分布其實是一種更精確的方式。另外,由于 teacher 和 student 在訓練階段所使用的損失是一致的,這也會讓 student 更容易學習到 teacher 的 score 分布。
接下來我們對 TAS-Balanced 策略做消融實驗,如下表所示。總體來說,TAS-balanced 策略加上 Dual-supervision 蒸餾可以在各個數據集上取得最優性能。值得關注的是,在單獨的 pairwise loss 的監督下使用 TAS 策略其實并不能帶來明顯的提升,這是因為 TAS 是面向 in-batch negative loss 設計的,使用 pairwise loss 訓練時,batch 內的樣本是沒有交互的,因此 TAS 也就不會起作用。而 TAS-balanced 策略會影響正負樣本對的組成方式,因此會對 pairwise loss 產生一定的影響。
4.2 Comparing to Baselines
下表對比了作者的模型和其他模型的性能,對比最后三行,我們可以發現一個有趣的現象:增大 batch size 在 TREC-DL 這類 densely-judge 的數據集上沒有帶來提升,但在 MSMACRO-DEV 這類 sparsely-judge 的數據集上會帶來持續的提升。?因此作者猜想增大 batch size 會導致模型在 sparsely-judge 的 MSMACRO 上過擬合,RocketQA 的 SOTA 表現可能僅僅是因為它的 batch size 夠大。
4.3 TAS-Balanced Retrieval in a Pipeline
為了進一步證明方法的有效性,作者嘗試將 TAS-Balance 訓練的檢索模型應用到召回-排序系統中。眾所周知,稠密檢索和稀疏檢索是互補的,且融合稀疏檢索幾乎不會影響召回速度,因此作者考慮將稀疏檢索的 docT5query 的檢索結果和 TAS-balanced 稠密檢索模型的結果融合,然后使用最先進的 mono-duo-T5 排序模型對檢索結果做重排。
選擇不同的召回模型、排序模型和不同大小的候選集,我們可以得到不同延遲水平的檢索系統。如上表所示,作者提出的模型在各個延遲水平上均取得了優異的表現。值得注意的是,在高延遲系統中,排序模型 mono-duo-T5 是在 BM25 的召回結果上訓練的,這實際上會導致訓練測試分布不一致的問題,所以 TAS-B+mono-duo-T5 甚至沒能超越 BM25+mono-duo-T5,為了取得更好的性能,我們應該先訓召回模型,然后在召回模型的給出召回結果上訓練排序模型,這其實也間接反映了當前的排序模型泛化性不足的問題。
Discussion
本篇論文最大的亮點是 TAS-Balanced 策略的高效性,使用作者的模型,我們僅需要在單個消費級 GPU 上從頭訓練 48 小時就能取得 SOTA 結果,極大地降低了檢索模型的訓練成本,這在之前是無法想象的。實際上,比起 NLP 社區,IR 社區更加強調模型和數據的 Efficiency,這一課題在將來也一定會受到持續的關注。
特別鳴謝
感謝 TCCI 天橋腦科學研究院對于 PaperWeekly 的支持。TCCI 關注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
·
總結
以上是生活随笔為你收集整理的48小时单GPU训练DistilBERT!这个检索模型轻松达到SOTA的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 2021胡润全球富豪榜 钟睒睒成全球前十
- 下一篇: 办不了信用卡是什么原因 办信用卡流程分享