ICLR 2020 | 可提速3000倍的全新信息匹配架构(附代码复现)
?PaperWeekly 原創(chuàng) ·?作者|周樹帆
學(xué)校|上海交通大學(xué)碩士生
研究方向|自然語(yǔ)言處理
今天聊一篇 FAIR 發(fā)表在 ICLR 2020 上的文章:Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring。
論文標(biāo)題:Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring
論文來(lái)源:ICLR 2020
論文鏈接:https://arxiv.org/abs/1905.01969
和一些花里胡哨但是沒(méi)有卵用的論文不同,這篇文章可謂大道至簡(jiǎn)。該文用一種非常簡(jiǎn)單但是有效的方式同時(shí)解決了 DSSM 式的 Bi-encoder 匹配質(zhì)量低的問(wèn)題和 ARC-II、BERT 等交互式的 Cross-encoder 匹配速度慢的問(wèn)題。
背景
眾所周知,常見(jiàn)的搜索、檢索式問(wèn)答、自然語(yǔ)言推斷等任務(wù),它們本質(zhì)上都是一種相關(guān)性匹配任務(wù):給定一段文本作為 query,然后匹配出最為相關(guān)的文檔或答案然后返回給用戶。
目前主流的文本相關(guān)性匹配架構(gòu)有兩大類:以 DSSM 為代表的 Siamese Network 架構(gòu)、以及形如 ARC-II、ABCNN 或 BERT(基于 Self-Attention)的交互式匹配架構(gòu)。
1.1 Siamese Network
如圖 1 所示,Siamese Network 式(本篇文章又稱其為 Bi-encoder)的匹配方案會(huì)利用 2 個(gè)網(wǎng)絡(luò)分別將 query 和 candidates 編碼成??和??,最后再通過(guò)一個(gè)相關(guān)性判別函數(shù)(通常為 cosine)計(jì)算兩個(gè) vec 之間的相似度。
這種方案的最大特點(diǎn)就是 query 和 candidates 直到最后的相關(guān)性判別函數(shù)時(shí)才發(fā)生交互,所以會(huì)對(duì)模型的匹配性能產(chǎn)生一定的影響。
但是這種完全獨(dú)立的編碼方式使得我們可以離線計(jì)算好所有 candidates 的向量,線上運(yùn)行時(shí)只需計(jì)算 query 的向量然后匹配已有向量即可??偟膩?lái)說(shuō),這種方案匹配速度極快,但是匹配質(zhì)量不能達(dá)到最佳。
▲?圖1. Siamese Network(本篇論文又稱其為Bi-encoder)
1.2 交互式匹配
如圖 2 所示,交互式匹配(本文記作 Cross-encoder)的核心思想是則是 query 和 candidates 時(shí)時(shí)刻刻都應(yīng)相互感知,相互交融,從而更深刻地感受到相互之間是否足夠匹配。
早期的交互方案如 ARC-II、ABCNN 等會(huì)計(jì)算??和??之間的word embedding相似度、Q、C 分別過(guò)? RNN 之后的??、??之間的相似度,最后再用一些 CNN 之類的方法整合結(jié)果,然后用 MLP 做二分類判別是相關(guān)還是不相關(guān)。
▲ 圖2. 交互式匹配示意圖(圖中為ARC-II)
另外在 BERT 興起之后,如圖 3 所示般將 query 和 candidate 拼成一句話,然后利用 self-attention 完成 query 和 candidate 之間的交互的模型也大量涌現(xiàn),并且取得了非常顯著的成果。本篇論文實(shí)現(xiàn)的 Cross-encoder 也是基于圖 3 的架構(gòu)。
相較于 Siamese Network,這類交互式匹配方案可以在 Q 和 C 之間實(shí)現(xiàn)更細(xì)粒度的匹配,所以通??梢匀〉酶玫钠ヅ湫Ч?。
但是很顯然,這類方案無(wú)法離線計(jì)算 candidates的表征向量,每處理一個(gè) query 都只能遍歷所有 (query, candidate) 的 pairs 依次計(jì)算相關(guān)性,所以這類方案相當(dāng)耗時(shí)(當(dāng)然也有很多提速手段,不過(guò)那不是本文的重點(diǎn))。
▲ 圖3. Cross-encoder
Poly-Encoder
Bi-encoder (Siamese Network) 和 Cross-encoder(交互式網(wǎng)絡(luò))都有各自顯著的優(yōu)點(diǎn)和缺點(diǎn),而本文提出的 Poly-encoder 架構(gòu)同時(shí)集成了兩類方案的優(yōu)點(diǎn)并避免了缺點(diǎn)。
▲ 圖4. Poly-encoder
Poly-encoder 如圖 4 所示。Poly-encoder 的思想非常簡(jiǎn)單(簡(jiǎn)單到論文里僅用了 2 段文字),按我的個(gè)人理解描述:
Bi-encoder 的主要問(wèn)題在于它要求 encoder 將 query 的所有信息都塞進(jìn)一個(gè)固定的比較 general 的向量中,這導(dǎo)致最后??和??計(jì)算相似度時(shí)已經(jīng)為時(shí)過(guò)晚,很多細(xì)粒度的信息丟失了(e.g. query 為“我要買蘋果”),所以無(wú)法完成更精準(zhǔn)的匹配。
這就有點(diǎn)像 word2vec 靜態(tài)詞向量:即使一個(gè)詞有多種語(yǔ)義,它的所有語(yǔ)義也不得不塞進(jìn)一個(gè)固定的詞向量。
為了克服這個(gè)問(wèn)題,Poly-encoder 的方案就是每個(gè) query 產(chǎn)生 m 個(gè)不同的??,接著再根據(jù)??動(dòng)態(tài)地將 m 個(gè)??集成為最終的??(其實(shí)有點(diǎn)像封面圖那樣,有一點(diǎn)用 m 個(gè)向量組合出最終的 Low Poly(baike.baidu.com/item/Lo)化向量的味道),最后再計(jì)算??和??的匹配度。
用論文里的話來(lái)說(shuō):
論文中的 ctxt 指代 context,相當(dāng)于 query;cand 指代 candidate。上面這段論文建議我們可以隨機(jī)初始化 m 個(gè)通過(guò) dot product 計(jì)算 attention,從而將長(zhǎng)度為 N 的 context 編碼成 m 個(gè)向量??(即??)。
接著:
我們?cè)儆?candidate 對(duì)應(yīng)的向量??計(jì)算 m 個(gè)??的 attention,進(jìn)而得到最終的?。
很顯然,Poly-encoder 架構(gòu)在實(shí)際部署時(shí)是可以離線計(jì)算好所有 candidates 的向量的,所以只需要計(jì)算 query 對(duì)應(yīng)的 m 個(gè)??向量,再通過(guò)簡(jiǎn)單的 dot product 就可以快速計(jì)算好對(duì)應(yīng)每個(gè) candidate 的“動(dòng)態(tài)的”?向量。
看起來(lái) Poly-encoder 享有 Bi-encoder 的速度,同時(shí)又有實(shí)現(xiàn)更精準(zhǔn)匹配的潛力。我們通過(guò)實(shí)驗(yàn)來(lái)一探究竟。
實(shí)驗(yàn)
本文選擇了檢索式對(duì)話數(shù)據(jù)集 ConvAI2、DSTC 7、Ubuntu v2 數(shù)據(jù)集以及 Wikipedia IR 數(shù)據(jù)集進(jìn)行實(shí)驗(yàn)。
訓(xùn)練 Bi-encoder 和 Poly-encoder 時(shí)由于這兩類模型的特性,負(fù)采樣方式為:在訓(xùn)練過(guò)程中,使用同一個(gè) batch 中的其他 query 對(duì)應(yīng)的 response 作為負(fù)樣本(如果難以理解,可以稍后結(jié)合復(fù)現(xiàn)代碼來(lái)理解)。
而 Cross-encoder 的負(fù)采樣方式為:在開(kāi)始訓(xùn)練之前,隨機(jī)采樣 15 個(gè) responses 作為負(fù)樣本。
3.1 檢索質(zhì)量
圖5給出了一些 baseline 模型以及本文的基于預(yù)訓(xùn)練 BERT 的 Bi-encoder、Poly-encoder 以及 Cross-encoder 在各個(gè)數(shù)據(jù)集上的表現(xiàn)。
當(dāng)然我們很容易發(fā)現(xiàn),本文的所有模型由于以預(yù)訓(xùn)練的BERT為基礎(chǔ),他們的表現(xiàn)都要顯著超出不使用 BERT的那些 baseline 們。所以我們只需要關(guān)注 Bi、Poly 和 Cross 三種架構(gòu)之間的表現(xiàn)差異即可。
實(shí)驗(yàn)結(jié)果表明即使僅增設(shè)少數(shù)幾個(gè) code(用于計(jì)算 attention 產(chǎn)生向量),Poly-encoder 的表現(xiàn)也要遠(yuǎn)優(yōu)于 Bi-encoder。
實(shí)驗(yàn)結(jié)果還表明,Poly-encoder 的表現(xiàn)會(huì)隨著 code 個(gè)數(shù)的增加而逐漸增加,并且慢慢逼近 Cross-encoder 的結(jié)果(個(gè)人認(rèn)為 Cross-encoder 的表現(xiàn)應(yīng)該是 Poly-encoder 的上界,不過(guò)偶爾也可能會(huì)因?yàn)橐恍┡既灰蛩貙?dǎo)致 Poly-encoder 反超 Cross-encoder 的情況)。
另外,為了體現(xiàn) Cross-encoder 在速度上的局限性,作者還很有意思地跳過(guò)了 Cross-encoder 在 Wikipedia IR 上的測(cè)評(píng)并寫到:“In addition, Cross-encoders are also too slow to evaluate on the evaluation setup of that task, which has 10k candidates”。
▲ 圖5. 模型表現(xiàn)匯總
3.2 檢索速度
圖 5 的實(shí)驗(yàn)結(jié)果已經(jīng)表明 Poly-encoder 的檢索質(zhì)量明顯優(yōu)于 Bi-encoder 架構(gòu),且能逼近 Cross-encoder 架構(gòu)的效果。剩下的關(guān)鍵問(wèn)題就是 Poly-encoder 是否會(huì)顯著增加檢索耗時(shí)?
圖 6 給出了各模型在 ConvAI2 數(shù)據(jù)集上的檢索耗時(shí)。
Bi-encoder 理所當(dāng)然是最快的架構(gòu),當(dāng) candidates 為 100k 時(shí),在 CPU 和 GPU 環(huán)境下其檢索耗時(shí)分別為 160ms 和 22ms;而 Cross-encoder 顯然是最慢的一個(gè):同樣實(shí)驗(yàn)條件下其檢索耗時(shí)分別約為 2.2M (220 萬(wàn)) ms 和 266K (26.6 萬(wàn)) ms。
反觀 Poly-encoder,以 Poly-encoder 360 為例,該模型可以達(dá)到遠(yuǎn)超 Bi-encoder、接近甚至反超 Cross-encoder 的檢索質(zhì)量,但其檢索速度確比 Cross-encoder 足足快了約 2600-3000 倍!
▲ 圖6. 各模型在ConvAI2數(shù)據(jù)集上的檢索耗時(shí)
論文小結(jié)
總的來(lái)說(shuō),本文的出發(fā)點(diǎn)就是希望找到一個(gè)速度快但質(zhì)量不足的 Bi-encoder 架構(gòu)和質(zhì)量高但速度慢的 Cross-encoder 架構(gòu)的折中。
本文提出的 Poly-encoder 的核心思想雖然非常簡(jiǎn)單,但是卻十分有效(親測(cè)),確實(shí)在很多場(chǎng)景下可以作為 Bi-encoder 的替代,甚至在一些對(duì)速度要求較高的場(chǎng)景下可以作為 Cross-encoder 的替代。
方案簡(jiǎn)潔固然是本文的一大優(yōu)點(diǎn),不過(guò)這也給未來(lái)的研究留下了空間。相信未來(lái)很快就會(huì)有許多基于 Poly-encoder 的改進(jìn)版出現(xiàn)。
復(fù)現(xiàn)結(jié)果分享
在讀完論文后的第一時(shí)間,我就嘗試了復(fù)現(xiàn)工作。我的復(fù)現(xiàn)結(jié)果表明,Poly-encoder 不管是收斂速度還是模型上限,都要顯著優(yōu)于 Bi-encoder,且 Poly-encoder 幾乎不增加額外的顯存負(fù)擔(dān),對(duì)訓(xùn)練速度的影響也幾乎可以忽略。完整代碼位于:
https://github.com/sfzhou5678/PolyEncoder
5.1 關(guān)鍵代碼分析
Poly-encoder 的實(shí)現(xiàn)非常簡(jiǎn)單,只需在 Bi-encoder 的基礎(chǔ)上略加修改即可。接下來(lái)我將介紹實(shí)現(xiàn) Poly-encoder 的核心代碼。
我們首先用 nn.embedding 來(lái)作為 m 個(gè) poly_codes 的值, 然后 forward 的時(shí)候根據(jù)m的值產(chǎn)生對(duì)應(yīng)個(gè)數(shù)的 poly_codes,這些 codes ?將用于計(jì)算不同的? attention weights,以產(chǎn)生多個(gè) vec_ctxt(即 vec_q)。
這里我令 poly_code_ids+=1 是為了讓 context_encoder 和 response_encoder 對(duì)稱,所以把 0 號(hào) id 留給了 response_encoder。
self.poly_code_embeddings?=?nn.Embedding(self.poly_m?+?1,?config.hidden_size)poly_code_ids?=?torch.arange(self.poly_m,?dtype=torch.long,?device=context_input_ids.device) poly_code_ids?+=?1 poly_code_ids?=?poly_code_ids.unsqueeze(0).expand(batch_size,?self.poly_m) poly_codes?=?self.poly_code_embeddings(poly_code_ids)接著,我們用這些 poly_codes 和 bert 的輸出做 attention 得到 context_vecs:
def?dot_attention(q,?k,?v,?v_mask=None,?dropout=None):attention_weights?=?torch.matmul(q,?k.transpose(-1,?-2))if?v_mask?is?not?None:attention_weights?*=?v_mask.unsqueeze(1)attention_weights?=?F.softmax(attention_weights,?-1)if?dropout?is?not?None:attention_weights?=?dropout(attention_weights)output?=?torch.matmul(attention_weights,?v)return?outputstate_vecs?=?self.bert(context_input_ids,?context_input_masks,?context_segment_ids)[0]??#?[bs,?length,?dim] context_vecs?=?dot_attention(poly_codes,?state_vecs,?state_vecs,?context_input_masks,?self.dropout)?#[bs,?m,?dim]得到 response_vec 的方式類似,不再贅述。最后,只需根據(jù) response_vec 給 context_vecs 做一次 attention 得到 final_context_vec 即可:
if?labels?is?not?None:responses_vec?=?responses_vec.view(1,?batch_size,?-1).expand(batch_size,?batch_size,?self.vec_dim)final_context_vec?=?dot_attention(responses_vec,?context_vecs,?context_vecs,?None,?self.dropout)在 loss function 方面,雖然我們可以在準(zhǔn)備數(shù)據(jù)的時(shí)候就為每個(gè)樣本做 N 次負(fù)采樣,但是在 Bi-encoder 或 Poly-encoder 這種產(chǎn)生 response_vec 和 query 完全獨(dú)立的場(chǎng)景下,可以將同一個(gè) batch 內(nèi)的其他 response 作為負(fù)樣本來(lái)避免重復(fù)計(jì)算,有效提升訓(xùn)練效率。
具體實(shí)現(xiàn)時(shí),我們計(jì)算 context_vec_i 和 response_vec_j 的點(diǎn)乘,從而產(chǎn)生一個(gè) [bs, bs] 的余弦相似度矩陣,這個(gè)相似度矩陣就是 context_vec_i 和 batch 內(nèi)的每一個(gè) response_vec 的相似度。
由于我們的目標(biāo)是最大化 context_vec_i 和對(duì)應(yīng)的正樣本,即 response_vec_i 的相似度,所以我們可以做一個(gè) [bs,bs] 的單位矩陣作為 label,最后應(yīng)用交叉熵產(chǎn)生訓(xùn)練用的 loss。
我的代碼中在 dot_product 后面還乘了系數(shù) 5,這就是一個(gè)用于緩和 softmax 取值的參數(shù),其具體取值通常需要實(shí)驗(yàn)來(lái)確定,這里的 5 只是我的經(jīng)驗(yàn)值。
#?因?yàn)橐阌嘞蚁嗨贫?#xff0c;所以給向量都?xì)w一化一下,之后直接點(diǎn)乘即可 context_vec?=?F.normalize(context_vec,?2,?-1) responses_vec?=?F.normalize(responses_vec,?2,?-1)responses_vec?=?responses_vec.squeeze(1) dot_product?=?torch.matmul(context_vec,?responses_vec.t())??#?[bs,?bs] mask?=?torch.eye(context_input_ids.size(0)).to(context_input_ids.device) loss?=?F.log_softmax(dot_product?*?5,?dim=-1)?*?mask loss?=?(-loss.sum(dim=1)).mean()5.2 實(shí)驗(yàn)結(jié)果
我使用的實(shí)驗(yàn)數(shù)據(jù)是論文中所用的 Ubuntu V2,實(shí)驗(yàn)設(shè)備是我筆記本上的一個(gè) 1066 顯卡。當(dāng)然為了實(shí)驗(yàn)跑得更快,我沒(méi)有使用論文中所用的 bert-base,而是一個(gè)預(yù)訓(xùn)練過(guò)的僅 4 層的 bert-small。
另外,此實(shí)驗(yàn)中所用的 batchsize、文本長(zhǎng)度、歷史對(duì)話信息等都限制的比較小(不然實(shí)驗(yàn)實(shí)在是跑得太慢了),因此實(shí)驗(yàn)結(jié)果整體會(huì)較原論文中偏低。
最終的實(shí)驗(yàn)設(shè)置和結(jié)果如下:
Dataset: Ubuntu V2
Device: GTX 1060 6G x1
Pretrained model:?BERT-small-uncased (https://storage.googleapis.com/bert_models/2020_02_20/all_bert_models.zip)
Batch size:?32
max_contexts_length: 128
max_context_cnt: 4
max_response_length:64
lr: 5e-5
Epochs: 3
Results:
▲ 復(fù)現(xiàn)實(shí)驗(yàn)結(jié)果匯總
從上表中明顯可以看出,Poly-encoder 的效果要遠(yuǎn)優(yōu)于 Bi-encoder 的,當(dāng)使用 16 個(gè) codes 時(shí),poly 較 bi 的提升可得到 2.24 個(gè)點(diǎn),而使用 64、360 個(gè) codes 時(shí)提升分別可達(dá) 3.12 和 3.52 個(gè)點(diǎn)。而且模型的訓(xùn)練速度幾乎沒(méi)有受到影響,同時(shí)對(duì)顯存的負(fù)擔(dān)也非常小。
總結(jié)
本文提出的 Poly-encoder 思路非常清晰,實(shí)現(xiàn)難度不高,而且實(shí)驗(yàn)效果非常理想,我個(gè)人非常喜歡!
Poly-encoder 架構(gòu)還有一個(gè)突出優(yōu)點(diǎn)在于,它可以很輕松地拓展到大量信息檢索相關(guān)的領(lǐng)域,無(wú)論是搜索、推薦,或是 CV 領(lǐng)域的 ReID 等,只要可以產(chǎn)生 query 和 candidates 的向量 vec_q 和 vec_c,那么都有可能成功應(yīng)用 Poly-encoder。
我自己十分看好 Poly-encoder,相信在未來(lái)它會(huì)成為和 DSSM 一樣的經(jīng)典必讀論文。
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
WWW 2020 開(kāi)源論文 | 異構(gòu)圖Transformer
Transformer的七十二變
抽取+生成:基于背景知識(shí)的參考感知網(wǎng)絡(luò)對(duì)話模型
二值神經(jīng)網(wǎng)絡(luò)(Binary Neural Networks)最新綜述
文本分類和序列標(biāo)注“深度”實(shí)踐
BERT在多模態(tài)領(lǐng)域中的應(yīng)用
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
?????來(lái)稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來(lái)稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結(jié)
以上是生活随笔為你收集整理的ICLR 2020 | 可提速3000倍的全新信息匹配架构(附代码复现)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 个人住房贷款调整 10月8日起开时实施
- 下一篇: 从近年CVPR看域自适应立体匹配