Transformer-XL解读(论文 + PyTorch源码)
前言
目前在NLP領域中,處理語言建模問題有兩種最先進的架構:RNN和Transformer。RNN按照序列順序逐個學習輸入的單詞或字符之間的關系,而Transformer則接收一整段序列,然后使用self-attention機制來學習它們之間的依賴關系。這兩種架構目前來看都取得了令人矚目的成就,但它們都局限在捕捉長期依賴性上。
為了解決這一問題,CMU聯合Google Brain在2019年1月推出的一篇新論文《Transformer-XL:Attentive Language Models beyond a Fixed-Length Context》同時結合了RNN序列建模和Transformer自注意力機制的優點,在輸入數據的每個段上使用Transformer的注意力模塊,并使用循環機制來學習連續段之間的依賴關系。Transformer-XL在多種語言建模數據集(如單詞級別的enwik8和字符級別的text8)上實現了目前的SoTA效果,且該模型在推理階段速度更快,比之前最先進的利用Transformer進行語言建模的方法快300~1800倍。?同時,該論文也放出了其配套源碼(包括TensorFlow和PyTorch的)、預訓練模型及在各個數據集上訓練的超參數,可以說是非常良心了~造福我等伸手黨!
本文將主要針對模型原理及其PyTorch實現進行逐一對照解讀,因筆者能力有限,如有不詳盡之處,可移步文末的傳送門進行詳細閱讀,并歡迎指出~
?
文章目錄
-
- 前言
- 一. 回顧Transformer
- 二. vanilla Transformer
- 三. Transformer-XL
-
- 1. 引入循環機制
- 2. 相對位置編碼
- 3. 整體計算公式
- 四. PyTorch實現
- 五. 實驗結果
-
- 1. 語言建模指標
- 2. 兩個創新點的優勢
- 3. 測試階段的速度
- 六. 總結
-
- 1. 模型特點
- 2. 優點
- 3. 不足
- 傳送門
?
一. 回顧Transformer
在NLP領域中,一種對語言建模的最常用模型就是RNN,它可以捕捉單詞之間的依賴關系。但因為梯度消失和爆炸的問題,RNN變得非常難以訓練,LSTM單元和梯度裁剪方法的提出也不足以解決此類問題。同時RNN網絡的計算速度往往很慢,其學習長期依賴的能力也較為有限(論文中提到,LSTM語言模型平均只能建模200個上下文詞語)。
2017年6月,Google Brain在論文《Attention Is All You Need》中提出的Transformer架構,完全摒棄了RNN的循環機制,采用一種self-attention的方式進行全局處理。其接收一整段序列,并使用三個可訓練的權重矩陣——Query、Key和Value來一次性學習輸入序列中各個部分之間的依賴關系。Transformer網絡由多個層組成,每個層都由多頭注意力機制和前饋網絡構成。由于在全局進行注意力機制的計算,忽略了序列中最重要的位置信息。Transformer為輸入添加了位置編碼(Positional Encoding),使用正弦函數完成,為每個部分的位置生成位置向量,不需要學習,用于幫助網絡學習其位置信息。其示意如下圖所示:
有關Transformer的更深入討論,可參考筆者之前的博客:
Transformer(論文 + PyTorch源碼解讀)
二. vanilla Transformer
為何要提這個模型?因為Transformer-XL是基于這個模型進行的改進。
Al-Rfou等人基于Transformer提出了一種訓練語言模型的方法(?https://arxiv.org/abs/1808.04444?),來根據之前的字符預測片段中的下一個字符。例如,它使用x 1 , x 2 , . . . , x n ? 1 x_1, x_2, ..., x_{n-1}x1?,x2?,...,xn?1?預測字符x n x_nxn?,而在x n x_nxn?之后的序列則被mask掉。論文中使用64層模型,并僅限于處理 512個字符這種相對較短的輸入,因此它將輸入分成段,并分別從每個段中進行學習,如下圖所示。 在測試階段如需處理較長的輸入,該模型會在每一步中將輸入向右移動一個字符,以此實現對單個字符的預測。
該模型在常用的數據集如enwik8和text8上的表現比RNN模型要好,但它仍有以下兩個缺點:
a. 上下文長度受限:字符之間的最大依賴距離受輸入長度的限制,模型看不到出現在幾個句子之前的單詞。
b. 上下文碎片:對于長度超過512個字符的文本,都是從頭開始單獨訓練的。段與段之間沒有上下文依賴性,會讓訓練效率低下,也會影響模型的性能。
c. 推理速度慢:在測試階段,每次預測下一個單詞,都需要重新構建一遍上下文,并從頭開始計算,這樣的計算速度非常慢。
三. Transformer-XL
Transformer-XL架構在vanilla Transformer的基礎上引入了兩點創新:循環機制(Recurrence Mechanism)和相對位置編碼(Relative Positional Encoding),以克服vanilla Transformer的缺點。與vanilla Transformer相比,Transformer-XL的另一個優勢是它可以被用于單詞級和字符級的語言建模。
1. 引入循環機制
與vanilla Transformer的基本思路一樣,Transformer-XL仍然是使用分段的方式進行建模,但其與vanilla Transformer的本質不同是在于引入了段與段之間的循環機制,使得當前段在建模的時候能夠利用之前段的信息來實現長期依賴性。如下圖所示:
在訓練階段,處理后面的段時,每個隱藏層都會接收兩個輸入:
這兩個輸入會被拼接,然后用于計算當前段的Key和Value矩陣。對于某個段的某一層的具體計算公式如下:
其中,τ \tauτ表示第幾段,n nn表示第幾層,h hh表示隱層的輸出。S G ( ? ) SG(·)SG(?)表示停止計算梯度,[ h u ° h v ] [h_u \circ h_v][hu?°hv?]表示在長度維度上的兩個隱層的拼接,W . W_.W.?是模型參數。乍一看與Transformer中的計算公式很像,唯一關鍵的不同就在于Key和Value矩陣的計算上,即k τ + 1 n k_{\tau+1}^nkτ+1n?和v τ + 1 n v_{\tau + 1}^nvτ+1n?,它們基于的是擴展后的上下文隱層狀態h ~ τ + 1 n ? 1 \tilde{h}_{\tau+1}^{n-1}h~τ+1n?1?進行計算,h τ n ? 1 {h}_{\tau}^{n-1}hτn?1?是之前段的緩存。
原則上只要GPU內存允許,該方法可以利用前面更多段的信息,測試階段也可以獲得更長的依賴。
在測試階段,與vanilla Transformer相比,其速度也會更快。在vanilla Transformer中,一次只能前進一個step,并且需要重新構建段,并全部從頭開始計算;而在Transformer-XL中,每次可以前進一整個段,并利用之前段的數據來預測當前段的輸出。
2. 相對位置編碼
在Transformer中,一個重要的地方在于其考慮了序列的位置信息。在分段的情況下,如果僅僅對于每個段仍直接使用Transformer中的位置編碼,即每個不同段在同一個位置上的表示使用相同的位置編碼,就會出現問題。比如,第i ? 2 i-2i?2段和第i ? 1 i-1i?1段的第一個位置將具有相同的位置編碼,但它們對于第i ii段的建模重要性顯然并不相同(例如第i ? 2 i-2i?2段中的第一個位置重要性可能要低一些)。因此,需要對這種位置進行區分。
論文對于這個問題,提出了一種新的位置編碼的方式,即會根據詞之間的相對距離而非像Transformer中的絕對位置進行編碼。在Transformer中,第一層的計算查詢q i T q_i^TqiT?和鍵k j k_jkj?之間的attention分數的方式為:
其中,E x i E_{x_i}Exi??是詞i ii的embedding,E x j E_{x_j}Exj??是詞j jj的embedding,U i U_iUi?和U j U_jUj?是位置向量,這個式子實際上是( W q ( E x i + U i ) ) T ? ( W k ( E x j + U j ) ) (W_q(E_{x_i}+U_i))^T·(W_k(E_{x_j}+U_j))(Wq?(Exi??+Ui?))T?(Wk?(Exj??+Uj?))的展開,就是Transformer中的標準格式。
在Transformer-XL中,對上述的attention計算方式進行了變換,轉為相對位置的計算,而且不僅僅在第一層這么計算,在每一層都是這樣計算。
對比來看,主要有三點變化:
從另一個角度來解讀這個公式的話,可以將attention的計算分為如下四個部分:
a. 基于內容的“尋址”,即沒有添加原始位置編碼的原始分數。
b. 基于內容的位置偏置,即相對于當前內容的位置偏差。
c. 全局的內容偏置,用于衡量key的重要性。
d. 全局的位置偏置,根據query和key之間的距離調整重要性。
3. 整體計算公式
結合上面兩個創新點,將Transformer-XL模型的整體計算公式整理如下,這里考慮一個N層的只有一個注意力頭的模型:
其中,τ \tauτ代表第幾段,n nn代表第幾層,h τ 0 : = E s τ h_\tau^0 := E_{s_\tau}hτ0?:=Esτ??定義為第τ \tauτ段的詞向量序列。值得一提的是,計算A AA矩陣的時候,需要對所有的i ? j i-ji?j計算W k , R n R i ? j W_{k,R}^nR_{i-j}Wk,Rn?Ri?j?,如果直接按照公式計算的話,計算時間是O ( l e n g t h ) 2 O(length)^2O(length)2,而實際上i ? j i-ji?j的范圍只從0 ~ length,因此可以先計算好這length個向量,然后在實際計算A AA矩陣時直接取用即可。
具體的,設M MM和L LL分別為memory和當前段序列的長度,則i ? j i-ji?j的范圍也就為0 ~?M + L ? 1 M + L - 1M+L?1。下面的Q QQ矩陣中的每一行都代表著W k , R R i ? j W_{k,R}R_{i-j}Wk,R?Ri?j?中一個i ? j i-ji?j的可能性,即Q k = W k , R R M + L ? 1 ? k Q_k = W_{k, R} R_{M+L-1-k}Qk?=Wk,R?RM+L?1?k?。
則對于上面公式中的(b)項,即q i T W k , R R i ? j q_i^TW_{k,R}R_{i-j}qiT?Wk,R?Ri?j?,其構成的所有可能向量的矩陣為B BB矩陣,其形狀為L ? ( M + L ) L * (M + L)L?(M+L),這是我們最終需要的(b)項的attention結果。
我們進一步定義B ~ \tilde{B}B~矩陣為如下:
可見,需要的B BB矩陣的每一行只是B ~ \tilde{B}B~的向左shift而已。因此,可以直接利用矩陣乘法計算B ~ \tilde{B}B~即可。設R i ? j R_{i-j}Ri?j?的維度為d R d_RdR?,q i q_iqi?的維度為d q d_qdq?,W k , R W_{k,R}Wk,R?矩陣的維度為d q ? d R d_q * d_Rdq??dR?,則直接計算矩陣B的時間復雜度為2 ? d q ? d R ? L ? ( M + L ) 2* d_q * d_R * L * (M+L)2?dq??dR??L?(M+L),而計算B ~ \tilde{B}B~的時間復雜度為L ? d q ? ( M + L ) + d q ? d R ? ( M + L ) L * d_q * (M + L) + d_q * d_R * (M + L)L?dq??(M+L)+dq??dR??(M+L),計算量明顯不是一個量級(后者要快很多)。
同理,對于(d)項來說,可以對所有的i ? j i-ji?j定義需要的矩陣D DD為L ? ( M + L ) L * (M+L)L?(M+L):
可以用如下的d ~ \tildeze8trgl8bvbqd~來進行shift得到:
其中Q QQ矩陣已經計算過了,也可以在這一步減少計算量。
四. PyTorch實現
筆者在這里主要研究的是核心模型部分,將針對關鍵的實現細節進行剖析,想要看完整代碼的讀者請戳這里。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
這里的demb是相對位置編碼的維度,pos_seq是序列的位置向量,在代碼里面是torch.arange(klen-1, -1, -1.0),其中的klen是mlen+qlen,從名稱和之前的原理介紹可知這里的mlen是memory的長度,qlen是query的長度,這兩者組成了key的長度。最終返回的即是R RR向量矩陣,可見是不需要學習的。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
其中n_head,d_model,d_head分別表示注意力頭的個數,模型的隱層維度,每個頭的隱層維度。qkv_net是用于計算query、key和value變換的參數矩陣W q , W k , E , W v W_{q}, W_{k,E}, W_{v}Wq?,Wk,E?,Wv?,與標準的Transformer中一致,o_net是用于將所有注意力頭的結果拼接后再變換到模型維度的參數矩陣,layer_norm是LayerNormalization層,r_net是用于計算relative position embedding變換的參數矩陣W k , R W_{k,R}Wk,R?。
在前向計算的過程中,w和r分別是上一層的輸出以及RelativePositionEmbedding,r_w_bias和r_r_bias分別是u uu向量和v vv向量,AC是前面公式中的(a)項和(c)項,BD是前面公式中的(b)項和(d)項,根據前面講的快速計算帶有相對位置的項,這里的BD需要進行偏移,即_rel_shift,經過筆者的演算,發現這里經過此函數后的BD并不是想要的B BB矩陣,其在B BB矩陣的(M+1)對角線(設主對角線為0,正數即為向右上偏移的量)的右上還有元素,不過后面緊接著就進行了mask。這里的attn_mask即為torch.triu(word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]。再往后就是標準的Transformer中的add&norm環節了,就不再贅述。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
這里的hids是當前段每層的輸出,mems為當前段每層依賴的memory,qlen為序列長度,mlen為當前段依賴的memory的長度。
從代碼來看的話,前面的循環示意圖似乎有些問題?感覺在訓練階段,對于每個段里面的第二個位置開始的點,都應該連到第一個位置連到的最前面memory?因為用的是同樣長度的memory。
五. 實驗結果
1. 語言建模指標
在最關心的語言模型建模指標上,論文比較了模型在單詞級別和字符級別上不同數據集的表現,并且與RNN和(vanilla) Transformer都做了比較。實驗證明,Transformer-XL在各個不同的數據集上均實現了目前的SoTA:在大型單詞級別數據集WikiText-103上,Transformer-XL將困惑度從20.5降到18.3;在enwiki8數據集上,12層Transformer-XL的bpc達到了1.06,相同bpc的AI-Rfou的模型(?https://arxiv.org/abs/1808.04444?)參數量卻是6倍,24層Transformer-XL的bpc更是達到了0.99;在One Billion Word數據集上(僅具有短句的)和Penn Treebank數據集上(小型,僅有1M)也取得了SoTA的效果,前者的困惑度從23.7到21.8,后者的困惑度從55.3到54.5。表明了Transformer-XL在各個數據集下的不俗競爭力。
2. 兩個創新點的優勢
下圖比較了不同上下文長度(即memory的長度)中包不包含循環機制、以及使不使用新位置編碼方式的困惑度得分。可見,使用循環機制和相對位置編碼的Transformer-XL明顯優于其他的模型,并且能夠有效利用長期依賴性,而且它能捕獲超出RNN 80%的依賴性,和超出Transformer 450%的依賴性。
3. 測試階段的速度
Transformer-XL的推理速度也明顯快于vanilla Transformer,尤其是對于較長的上下文。比如,在上下文長度為800時,Transformer-XL提速363倍;而當上下文長度增加到3800時,Transformer-XL提速1874倍!
六. 總結
1. 模型特點
在 AI-Rfou 等人提出的vanilla Transformer上做了兩點創新:
2. 優點
3. 不足
傳送門
論文:https://arxiv.org/pdf/1901.02860.pdf
代碼:https://github.com/kimiyoung/transformer-xl
參考:https://www.lyrn.ai/2019/01/16/transformer-xl-sota-language-model
總結
以上是生活随笔為你收集整理的Transformer-XL解读(论文 + PyTorch源码)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: The Illustrated Tran
- 下一篇: BERT Word Embeddings