【自然语言处理】Transformer 讲解
有任何的書寫錯(cuò)誤、排版錯(cuò)誤、概念錯(cuò)誤等,希望大家包含指正。
在閱讀本篇之前建議先學(xué)習(xí):
【自然語(yǔ)言處理】Seq2Seq 講解
【自然語(yǔ)言處理】Attention 講解
Transformer
為了講解更加清晰,約定“預(yù)測(cè)階段”被稱為“推斷階段”(inference),“預(yù)測(cè)”用于表示模型根據(jù)輸入信息輸出目標(biāo)信息的抽象過(guò)程。
1. 簡(jiǎn)介
在 Transformer 出現(xiàn)之前,大部分序列轉(zhuǎn)換(轉(zhuǎn)錄)模型是基于 RNNs 或 CNNs 的 Encoder-Decoder 結(jié)構(gòu)。但是 RNNs 固有的順序性質(zhì)使得并行計(jì)算難以實(shí)現(xiàn),即訓(xùn)練時(shí)當(dāng)前時(shí)刻的隱藏狀態(tài)與前一個(gè)時(shí)刻的隱藏狀態(tài)有關(guān),這意味著需要先計(jì)算出前一個(gè)時(shí)刻的狀態(tài)才能計(jì)算下一個(gè)時(shí)刻的狀態(tài),這大大限制了 RNNs 的訓(xùn)練速度;CNNs 可以比較好的解決并行計(jì)算的問(wèn)題,但是對(duì)于長(zhǎng)序列 CNNs 難以建模,需要設(shè)置非常多的卷積層才能將較長(zhǎng)距離的部分聯(lián)系起來(lái),可以想象,大小合理的卷積核只能對(duì)序列的某一部分進(jìn)行關(guān)聯(lián),當(dāng)卷積層疊加到一定層數(shù)后,才能將序列中最遠(yuǎn)距離的兩個(gè)部分相關(guān)聯(lián)。Transformer 徹底摒棄了 RNNs 和 CNNs,是一種完全基于注意力機(jī)制的 Encode-Decoder 模型。它的預(yù)測(cè)效果非常出眾,同時(shí)優(yōu)質(zhì)的并行性使得它的訓(xùn)練時(shí)間更短。
序列轉(zhuǎn)換(轉(zhuǎn)錄)模型(sequence transduction models)是指將一個(gè)序列轉(zhuǎn)換為另一個(gè)序列的模型。
2. 模型結(jié)構(gòu)
2.1. 整體框架
上面提到 Transformer 是序列轉(zhuǎn)換模型,將其視為一個(gè)黑盒子,那么輸入和輸出分別是兩個(gè)序列,如圖 111 所示。
圖 1????Transformer 模型 (黑盒模型)
將黑盒子一步步打開(kāi)來(lái)看內(nèi)部細(xì)節(jié)。
Transformer 仍然采用 Encoder-Decoder 框架,如圖 222 所示。
圖 2????Transformer 模型 (Encoder-Decoder 框架)
其中,ENCODERS 部分由多層編碼器首尾相連組成,DECODERS 同理;二者之間的連接是將 ENCODERS 最后一層編碼器的輸出作為 DECODERS 每一層編碼器的部分輸入。論文中作者規(guī)定 ENCODERS 和 DECODERS 均為 6 層,如圖 333 所示。
圖 3????Transformer 模型 (6 層編碼器和解碼器)
圖 444 展示了更為具體的內(nèi)部細(xì)節(jié)。每個(gè)編碼器包括兩個(gè)子層:Multi-Head self-attention(多頭自注意力)子層和 Position-wise Feed Forward Network(按位置操作的前饋神經(jīng)網(wǎng)絡(luò),簡(jiǎn)記 FFN)子層。注意到,兩個(gè)子層除了主體模塊外,還包括 Add & Norm 模塊。Add 和 Norm 分別對(duì)應(yīng) Residual Connection 和 Layer Normalization,即殘差連接和層標(biāo)準(zhǔn)化。解碼器在編碼器的兩個(gè)子層的基礎(chǔ)上還在兩層之間添加了一層樸素的多頭注意力子層,即非自注意力子層。
圖 4????Transformer 中的編碼器和解碼器
對(duì)于圖 444 的理解需要配合圖 333。
2.2. 自注意力機(jī)制
2.2.1. 作用
對(duì)于兩個(gè)僅有一詞不同的語(yǔ)句,The animal didn’t cross the street because it was too tired 和 The animal didn’t cross the street because it was too wide,我們可以輕松判斷出第一個(gè)語(yǔ)句中的 it 表示 animal,第二個(gè)語(yǔ)句中的 it 表示 street。從注意力的角度來(lái)看,第一個(gè)語(yǔ)句中的 it 會(huì)更加關(guān)注 animal,第二個(gè)語(yǔ)句中的 it 會(huì)更加關(guān)注 street。
Query、Key 和 Value 均來(lái)自同一個(gè)輸入序列的注意力機(jī)制稱為自注意力機(jī)制(self-attention mechanism)。自注意力機(jī)制對(duì)于序列轉(zhuǎn)換任務(wù)的意義在于,幫助模型理解序列中元素之間的關(guān)系,亦語(yǔ)句中單詞的語(yǔ)義關(guān)系。正如上面的例子,經(jīng)過(guò)自注意力模塊,模型可以學(xué)習(xí)到不同語(yǔ)境下的 it 與句子中不同的單詞有關(guān),這使得模型對(duì)語(yǔ)義的理解更加深刻。
2.2.2. 計(jì)算
自注意力計(jì)算過(guò)程如下:
計(jì)算過(guò)程(簡(jiǎn)化,省略縮放操作)如圖 555 所示。
圖 5????自注意力計(jì)算
2.2.3. 多頭注意力
多頭的思想來(lái)自多通道卷積,通過(guò)不同通道識(shí)別不同的模式,從不同角度理解序列。還是以上面提到的語(yǔ)句為例,The animal didn’t cross the street because it was too tired,不同的注意力頭關(guān)注序列中不同的位置。如圖 666 所示。
區(qū)別多頭注意力和自注意力,前者強(qiáng)調(diào)通道個(gè)數(shù),后者強(qiáng)調(diào) Q、K 和 V 的取法。二者是可以融合使用的,這在 Transformer 中有很多體現(xiàn)。
圖 6????雙頭自注意力
可以看出,橙色注意力頭更加關(guān)注 it 所指代的單詞 animal,綠色注意力頭更加關(guān)注 it 所指代單詞的狀態(tài) tired。
以雙頭自注意力為例,符號(hào)上標(biāo)表示頭序號(hào),第一個(gè)單詞注意力的計(jì)算過(guò)程(簡(jiǎn)化,省略 Softmax 層)如圖 777 所示。
圖 7????第一個(gè)單詞的雙頭自注意力計(jì)算
對(duì)比圖 555 和圖 777 可以發(fā)現(xiàn),單頭注意力僅涉及一組參數(shù)矩陣 WQW_QWQ?、WKW_KWK? 和 WVW_VWV?,hhh頭注意力包括 hhh 組參數(shù)矩陣 WQ1:hW_Q^{1:h}WQ1:h?、WK1:hW_K^{1:h}WK1:h? 和 WV1:hW_V^{1:h}WV1:h?。
在 Transformer 模型中,多頭注意力保證參數(shù)數(shù)量保持不變,即 WQW_QWQ? 元素個(gè)數(shù)與 WQ1:hW_Q^{1:h}WQ1:h? 元素個(gè)數(shù)相等,WKW_KWK? 元素個(gè)數(shù)與 WK1:hW_K^{1:h}WK1:h? 元素個(gè)數(shù)相等,WVW_VWV? 元素個(gè)數(shù)與 WV1:hW_V^{1:h}WV1:h? 元素個(gè)數(shù)相等,這表明 WQiW_Q^iWQi? 的列數(shù)是 WQW_QWQ? 的 1/h1/h1/h,WKiW_K^iWKi? 和 WViW_V^iWVi? 同理。
在代碼實(shí)現(xiàn)時(shí),完全可以將 WQ1:hW^{1:h}_QWQ1:h? 拼接成 WQW_QWQ? 的樣式實(shí)現(xiàn)并行計(jì)算多頭注意力。具體地,XWQ1:h=Q1:hXW_Q^{1:h}=Q^{1:h}XWQ1:h?=Q1:h,Q1:hQ^{1:h}Q1:h 由 Q1=(Q11,Q21,Q31)Q^1=(Q_1^1,Q_2^1,Q_3^1)Q1=(Q11?,Q21?,Q31?)、Q2=(Q12,Q22,Q32)Q^2=(Q_1^2,Q_2^2,Q_3^2)Q2=(Q12?,Q22?,Q32?)、…\dots… 、Qh=(Q1h,Q2h,Q3h)Q^h=(Q^h_1,Q_2^h,Q_3^h)Qh=(Q1h?,Q2h?,Q3h?) 拼接而成,其中 X=(x1;x2;x3)X=(x_1;x_2;x_3)X=(x1?;x2?;x3?),K1:hK^{1:h}K1:h 和 V1:hV^{1:h}V1:h 的含義及計(jì)算同理。這樣多頭注意力的計(jì)算就完全轉(zhuǎn)換成了單頭注意力的計(jì)算。顯然,最后計(jì)算出的注意力也是由多個(gè)頭的注意力拼接得到,即 Z=(z1,z2,…,zh)Z=(z^1,z^2,\dots, z^h)Z=(z1,z2,…,zh),zi=(z1i,z2i,z3i)z^i=(z_1^i,z_2^i,z_3^i)zi=(z1i?,z2i?,z3i?)。最后,為了融合每個(gè)頭得到的信息,ZZZ 還需要經(jīng)過(guò)一層同維線性映射(映射矩陣的行數(shù)與列數(shù)相等)。
多頭注意力所謂的“參數(shù)數(shù)量不變”不包括最后的融合線性映射的參數(shù)。
多頭自注意力模塊的理論示意圖(并非具體實(shí)現(xiàn)的方法)如圖 888 所示。可以發(fā)現(xiàn),自注意力模塊的輸入和輸出序列長(zhǎng)度一致,且向量 xix_ixi? 和 ziz_izi? 的維度一致,我們記其維度為 dmodel=512d_{\rm model}=512dmodel?=512。輸入與輸出長(zhǎng)度和維度相等的模塊貫穿整個(gè) Transformer 模型,編碼器和解碼器由這種模塊組成的好處在于前一個(gè)編碼器(或解碼器)的輸出可以直接作為下一個(gè)編碼器(或解碼器)的輸入,而且最后一個(gè)編碼器的輸出也可以直接輸入到每個(gè)解碼器中,這在一定程度上簡(jiǎn)化了模型。
圖 8????多頭自注意力模塊的理論示意圖
2.3. 按位置操作的前饋神經(jīng)網(wǎng)絡(luò)
參考圖 444 可以知道,在忽略 Add & Norm 層的前提下,無(wú)論是在編碼器還是解碼器中,按位置操作的前饋神經(jīng)網(wǎng)絡(luò)(Position-wise Feed-Forward Networks,FFN)都是以注意力模塊的輸出 z1z_1z1?,z2z_2z2?,z3z_3z3? 作為輸入。可以認(rèn)為 FFN 是單隱層的感知機(jī),FFN 的輸入為單個(gè)單詞對(duì)應(yīng)的注意力向量 ziz_izi?,先映射到高維,再映射回原來(lái)的維度輸出。輸入層神經(jīng)元個(gè)數(shù)與注意力向量的維度相等,即 dmodel=512d_{\rm model}=512dmodel?=512,隱藏層神經(jīng)元個(gè)數(shù)為 4×dmodel=20484\times d_{\rm model}=20484×dmodel?=2048,輸出層神經(jīng)元個(gè)數(shù)為 dmodel=512d_{\rm model} = 512dmodel?=512。可見(jiàn),FFN 層也滿足輸入和輸出的維度相等。
有兩點(diǎn)需要注意。一是,注意到 FFN 的輸入層神經(jīng)元個(gè)數(shù)與單個(gè)單詞注意力向量維度一致,這說(shuō)明對(duì)于某個(gè) FFN 層而言,雖然一次性接收來(lái)自注意力模塊的多個(gè)單詞注意力向量,但不會(huì)一次性輸入到多層感知機(jī)中,否則輸入層神經(jīng)元個(gè)數(shù)應(yīng)該是注意力向量維度乘以序列長(zhǎng)度,即對(duì)于長(zhǎng)度為 333 的序列,輸入層神經(jīng)元個(gè)數(shù)為 3×dmodel3\times d_{\rm model}3×dmodel?。在代碼實(shí)現(xiàn)時(shí),為了并行處理,會(huì)將全部單詞向量堆疊成的矩陣作為多層感知機(jī)的輸入。另外一種更簡(jiǎn)單的理解方式是,注意力模塊輸出的每個(gè)單詞注意力向量都對(duì)應(yīng)一個(gè)多層感知機(jī),只不過(guò)這些同結(jié)構(gòu)的多層感知機(jī)共享參數(shù),這正是 FNN 稱為 Position-wise 的原因,即每個(gè)單詞對(duì)應(yīng)一個(gè)位置,每個(gè)位置對(duì)應(yīng)一個(gè)前饋神經(jīng)網(wǎng)絡(luò)。兩種理解方式如圖 999 和 101010 所示。二是,FFN 隱藏層的輸出會(huì)經(jīng)過(guò) ReLU 激活函數(shù),ReLU 激活函數(shù)提供非線性變換,FFN 對(duì)應(yīng)的公式為:
FFN(zi)=max?(0,ziW1+b1)W2+b2{\rm FFN}(z_i) = \max (0, z_i W_1+b_1)W_2+b_2 FFN(zi?)=max(0,zi?W1?+b1?)W2?+b2?
圖 9????FFN 示意圖(一)
圖 10????FFN 示意圖(二)
2.4. 殘差連接和層規(guī)范化
個(gè)人習(xí)慣,這里統(tǒng)一稱為“規(guī)范化”,而不稱為“標(biāo)準(zhǔn)化”或“歸一化”。
Add & Norm 層更嚴(yán)謹(jǐn)?shù)拿Q應(yīng)該是 Residual Connection & Layer Normalization,本層的任務(wù)是進(jìn)行殘差連接和層規(guī)范化(簡(jiǎn)記為 LN)。假設(shè) Add & Norm 層的前一層的輸入記為 X=(x1;x2;x3)X=(x_1;x_2;x_3)X=(x1?;x2?;x3?),輸出記為 Y=(y1;y2;y3)Y=(y_1;y_2;y_3)Y=(y1?;y2?;y3?)。Add & Norm 層的具體任務(wù)是對(duì) X+YX+YX+Y 進(jìn)行層規(guī)范化。
殘差連接的思想來(lái)自 ResNet,屬于非常基礎(chǔ)的知識(shí),這里不再介紹。重點(diǎn)講解一下 BN(Batch Normalization)和 LN(Layer Normalization)。
2.4.1. Batch Normalization
對(duì)于由向量表示的樣本而言,一個(gè) batch 對(duì)應(yīng)一個(gè)矩陣(B×DB\times DB×D),比如以身高、體重等為特征來(lái)描述一個(gè)人,抽象為如圖 11(左)11\;(左)11(左) 所示的矩形;對(duì)于由矩陣表示的樣本而言,一個(gè) batch 是由多個(gè)樣本矩陣堆疊而成的張量(B×N×DB\times N\times DB×N×D),常見(jiàn)于自然語(yǔ)言處理任務(wù)中,張量的第一維表示序列個(gè)數(shù),第二維表示序列長(zhǎng)度,第三維表示序列中每個(gè)元素對(duì)應(yīng)向量的維度,抽象為如圖 11(右)11\;(右)11(右) 所示的立方體。
圖 11????矩陣 batch (左) 和張量 batch (右)
對(duì)于矩陣 batch 而言,Batch Normalization 以 batch 內(nèi)全部樣本的同一維度為一組,計(jì)算均值和方差并進(jìn)行規(guī)范化,圖 11(左)11\;(左)11(左) 同箭頭跨越的部分為一組規(guī)范化對(duì)象;對(duì)于張量 batch 而言,BN 以 batch 內(nèi)全部序列,同一位置對(duì)應(yīng)單詞的同一維度為一組,計(jì)算均值和方差并進(jìn)行規(guī)范化,圖 11(右)11\;(右)11(右) 同箭頭跨越的部分為一組規(guī)范化對(duì)象。
我們知道 batch 的概念只出現(xiàn)在訓(xùn)練階段,因此,在訓(xùn)練時(shí)會(huì)保存每一個(gè) batch 計(jì)算的均值和方差,通過(guò)如下公式計(jì)算全局均值和方差作為推斷階段要使用的均值和方差。
μglobal=E[μi]σglobal2=mm?1E[σi2]\mu_{\rm global}= E[\mu_i]\\ \sigma^2_{\rm global}= \frac{m}{m-1} E[\sigma^2_i] μglobal?=E[μi?]σglobal2?=m?1m?E[σi2?]
其中 μi\mu_iμi? 和 σi2\sigma^2_iσi2? 為第 iii 個(gè) batch 對(duì)應(yīng)的均值和方差。需要注意,每組 batch 的每個(gè)維度都有一對(duì)均值和方差。
2.4.2. Layer Normalization
自然語(yǔ)言處理任務(wù)常用 Layer Normalization,而不用 Batch Normalization。這是因?yàn)?#xff0c;每個(gè)序列的長(zhǎng)度往往不一致,使用 BN 會(huì)出現(xiàn)一組內(nèi)進(jìn)行規(guī)范化的元素非常少的情況,此時(shí)在訓(xùn)練數(shù)據(jù)中計(jì)算出的均值和方差很可能與推斷數(shù)據(jù)中的均值和方差不匹配。另外,使用 BN 需要保證推斷序列長(zhǎng)度不超過(guò)訓(xùn)練序列最大長(zhǎng)度,否則會(huì)出現(xiàn)推斷序列多出的位置在訓(xùn)練階段沒(méi)有均值和方差被計(jì)算、被保存,也就無(wú)法在推斷階段對(duì)這些位置進(jìn)行 BN。
這兩點(diǎn)問(wèn)題都說(shuō)明 BN 在自然語(yǔ)言任務(wù)中的效果不令人滿意,因此引入 LN。LN 以每個(gè)單詞向量為一組進(jìn)行規(guī)范化,如圖 12(右)12\;(右)12(右) 所示。
圖 12????BN (左) 和 LN (右) 對(duì)比
如果使用 LN,那么不再需要保存任何訓(xùn)練階段的均值和方差,推斷階段可以根據(jù)輸入實(shí)時(shí)計(jì)算均值和方差。另外,自然語(yǔ)言處理任務(wù)可以使用 LN 的一個(gè)關(guān)鍵原因是,詞向量的每一維度之間不像身高、體重似的存在非常明顯的量綱,這保證了數(shù)據(jù)無(wú)需進(jìn)行 BN。
目前眾多的實(shí)驗(yàn)結(jié)果表明,BN 在 MLP 和 CNN 上表現(xiàn)優(yōu)異,但在 RNN 上效果不明顯。
【小筆記】在編寫對(duì)應(yīng)的代碼時(shí)發(fā)現(xiàn),Pytorch 中的 torch.nn.LayerNorm 函數(shù)是根據(jù)總體方差(標(biāo)準(zhǔn)差)進(jìn)行的層規(guī)范化,這與用 Numpy 中的 var 函數(shù)手動(dòng)實(shí)現(xiàn)層規(guī)范化的結(jié)果是一致的;而 Pytorch 中的 var 函數(shù)計(jì)算的是樣本方差,所以手動(dòng)實(shí)現(xiàn)出的結(jié)果與前面兩者不同。
2.5. 位置編碼
Transformer 模型完全基于注意力機(jī)制,而注意力機(jī)制無(wú)法捕獲有關(guān)時(shí)序(詞序)的信息,喪失時(shí)序信息可能使模型對(duì)語(yǔ)句的理解產(chǎn)生歧義。舉個(gè)例子,He is reading English. 和 Is he reading English ?,兩個(gè)語(yǔ)句由完全相同的單詞組成,詞序的不同使得這四個(gè)單詞產(chǎn)生了陳述句和疑問(wèn)句。這個(gè)例子的問(wèn)題可以通過(guò)對(duì)標(biāo)點(diǎn)符號(hào)編碼并與其他單詞一同輸入到模型中的方式解決。那么再看個(gè)例子,Even now he doesn’t believe me.、Now even he doesn’t believe me. 和 Now he doesn’t believe even me.,這三個(gè)語(yǔ)句由完全相同的單詞和標(biāo)點(diǎn)組成,卻表達(dá)了不同的含義,依次為「甚至到現(xiàn)在他還不相信我(其它時(shí)候就更不用說(shuō)了)」、「現(xiàn)在連他都不相信我了(其他人就更不會(huì)相信我了)」和「現(xiàn)在他甚至連我都不相信了(就更不會(huì)相信其他人了)」。由此可見(jiàn),時(shí)序信息對(duì)于模型理解語(yǔ)義是有幫助的。
為了解決這個(gè)問(wèn)題,需要向模型中注入單詞在序列中的絕對(duì)或者相對(duì)位置信息。Transformer 中第一個(gè)編碼器的輸入和第一個(gè)解碼器的輸入是由位置編碼(positional encodings)和詞嵌入(embedding)相加得到,因此,序列中第 iii 個(gè)單詞的表示為 xi=pi+eix_i=p_i+e_ixi?=pi?+ei?。顯然,需要保證 pip_ipi? 和 eie_iei? 具有相同的維度 dmodeld_{\rm model}dmodel?。
上面提到過(guò)詞嵌入 eie_iei? 可以由預(yù)訓(xùn)練模型生成,也可以隨機(jī)初始化跟隨 Transformer 模型學(xué)習(xí)得到;位置編碼同樣也可以跟隨模型學(xué)習(xí)確定,但是常用不同頻率的正弦和余弦函數(shù)對(duì)位置進(jìn)行固定編碼。經(jīng)證明,三角函數(shù)編碼與學(xué)習(xí)向量這兩種方法效果近似,但是三角函數(shù)編碼的優(yōu)勢(shì)在于,其使得模型能夠處理比訓(xùn)練階段遇到的序列長(zhǎng)度更長(zhǎng)的序列。
不同頻率的正弦函數(shù)和余弦函數(shù)定義如下:
PE(pos,2i)=sin?(pos/100002i/dmodel)PE(pos,2i+1)=cos?(pos/100002i/dmodel){\rm PE}(pos, 2i) = \sin (pos/10000^{2i/d_{\rm model}})\\ {\rm PE}(pos, 2i+1) = \cos (pos/10000^{2i/d_{\rm model}}) PE(pos,2i)=sin(pos/100002i/dmodel?)PE(pos,2i+1)=cos(pos/100002i/dmodel?)
其中,pospospos 表示單詞在序列中的位置,范圍 [0,seq_len?1][0,\rm seq\_len-1][0,seq_len?1];dmodeld_{\rm model}dmodel? 必須為偶數(shù),在代碼實(shí)現(xiàn)中取值為 512512512;iii 表示向量 PE\rm PEPE 某一維度的索引值除以 222 的商(除以 222 向下取整),范圍 [0,dmodel/2?1][0, d_{\rm model}/2-1][0,dmodel?/2?1]。pospospos 處的位置編碼向量 PEpos{\rm PE}_{pos}PEpos? 由兩個(gè)函數(shù)交替拼接得到:
PEpos=[sin?(w0?pos),cos?(w0?pos),…,sin?(wi?pos),cos?(wi?pos),…,sin?(wdmodel/2?1?pos),cos?(wdmodel/2?1?pos)]{\rm PE}_{pos} = [\sin(w_0·pos),\cos(w_0·pos),\dots, \sin(w_i·pos),\cos(w_i·pos),\dots, \sin(w_{{d_{\rm model}/2}-1}·pos),\cos(w_{{d_{\rm model}/2}-1}·pos)] PEpos?=[sin(w0??pos),cos(w0??pos),…,sin(wi??pos),cos(wi??pos),…,sin(wdmodel?/2?1??pos),cos(wdmodel?/2?1??pos)]
其中 wi=1/100002i/dmodelw_i = 1/10000^{2i/d_{\rm model}}wi?=1/100002i/dmodel? 。可見(jiàn),位置編碼的每個(gè)維度對(duì)應(yīng)于一個(gè)正弦(或余弦)曲線。這些波長(zhǎng)形成一個(gè)從 2π2\pi2π 到 10000?2π10000 \cdot 2\pi10000?2π 的集合級(jí)數(shù)。
這樣編碼不僅能夠提供單詞的絕對(duì)位置信息,還隱含著單詞的相對(duì)距離信息。討論 pos+kpos+kpos+k 處與 pospospos 處的位置編碼的關(guān)系。pos+kpos+kpos+k 處的位置編碼可以表示為:
PE(pos+k,2i)=sin?(wi?(pos+k))PE(pos+k,2i+1)=cos?(wi?(pos+k)){\rm PE}(pos+k, 2i) = \sin (w_i · (pos+k))\\ {\rm PE}(pos+k, 2i+1) = \cos (w_i · (pos+k)) PE(pos+k,2i)=sin(wi??(pos+k))PE(pos+k,2i+1)=cos(wi??(pos+k))
根據(jù)三角函數(shù)公式
sin?(α+β)=sin?α?cos?β+cos?α?sin?βsin?(α+β)=cos?α?cos?β?sin?α?sin?β\sin (\alpha+\beta) = \sin \alpha · \cos \beta + \cos \alpha ·\sin \beta \\ \sin (\alpha+\beta) = \cos \alpha · \cos \beta - \sin \alpha ·\sin \beta \\ sin(α+β)=sinα?cosβ+cosα?sinβsin(α+β)=cosα?cosβ?sinα?sinβ
可得
PE(pos+k,2i)=sin?(wi?(pos+k))=sin?(wi?pos)?cos?(wi?k)+cos?(wi?pos)?sin?(wi?k)PE(pos+k,2i+1)=cos?(wi?(pos+k))=cos?(wi?pos)?cos?(wi?k)?sin?(wi?pos)?sin?(wi?k){\rm PE}(pos+k, 2i) = \sin (w_i · (pos+k))= \sin (w_i ·pos) · \cos (w_i ·k) + \cos (w_i ·pos) ·\sin (w_i ·k)\\ {\rm PE}(pos+k, 2i+1) = \cos (w_i · (pos+k)) = \cos (w_i ·pos) · \cos (w_i ·k) - \sin (w_i ·pos) ·\sin (w_i · k) PE(pos+k,2i)=sin(wi??(pos+k))=sin(wi??pos)?cos(wi??k)+cos(wi??pos)?sin(wi??k)PE(pos+k,2i+1)=cos(wi??(pos+k))=cos(wi??pos)?cos(wi??k)?sin(wi??pos)?sin(wi??k)
代入 PE(pos+k,2i)=sin?(wi?pos){\rm PE}(pos+k, 2i) = \sin(w_i·pos)PE(pos+k,2i)=sin(wi??pos) 和 PE(pos+k,2i+1)=cos?(wi?pos){\rm PE}(pos+k, 2i+1)=\cos(w_i·pos)PE(pos+k,2i+1)=cos(wi??pos) 得
PE(pos+k,2i)=cos?(wi?k)?PE(pos,2i)+sin?(wi?k)?PE(pos,2i+1)PE(pos+k,2i+1)=cos?(wi?k)?PE(pos,2i+1)?sin?(wi?k)?PE(pos,2i){\rm PE}(pos+k, 2i) = \cos (w_i ·k)· {\rm PE}(pos,2i) + \sin (w_i ·k)·{\rm PE}(pos, 2i+1) \\ {\rm PE}(pos+k, 2i+1) = \cos (w_i ·k) · {\rm PE}(pos, 2i+1) - \sin (w_i · k)·{\rm PE}(pos, 2i) PE(pos+k,2i)=cos(wi??k)?PE(pos,2i)+sin(wi??k)?PE(pos,2i+1)PE(pos+k,2i+1)=cos(wi??k)?PE(pos,2i+1)?sin(wi??k)?PE(pos,2i)
kkk 是常數(shù),不妨將有關(guān) kkk 是三角函數(shù)記為
u=cos?(wi?k)v=sin?(wi?k)u = \cos(w_i · k) \\ v = \sin(w_i · k) u=cos(wi??k)v=sin(wi??k)
這樣,PE(pos+k,2i){\rm PE}(pos+k, 2i)PE(pos+k,2i) 和 PE(pos+k,2i+1){\rm PE}(pos+k, 2i+1)PE(pos+k,2i+1) 可以以矩陣的形式表達(dá):
[PE(pos+k,2i)PE(pos+k,2i+1)]=[uv?vu]×[PE(pos,2i)PE(pos,2i+1)]\left[ \begin{matrix} {\rm PE}(pos+k, 2i) \\ {\rm PE}(pos+k, 2i+1) \\ \end{matrix} \right]= \left[ \begin{matrix} u & v\\ -v & u \end{matrix} \right] \times \left[ \begin{matrix} {\rm PE}(pos, 2i) \\ {\rm PE}(pos, 2i+1) \\ \end{matrix} \right] [PE(pos+k,2i)PE(pos+k,2i+1)?]=[u?v?vu?]×[PE(pos,2i)PE(pos,2i+1)?]
從矩陣乘法的角度可以看出,pos+kpos+kpos+k 處與 pospospos 處位置編碼存在線性關(guān)系。進(jìn)一步,計(jì)算兩個(gè)位置編碼向量的內(nèi)積:
<PE(pos),PE(pos+k)>=∑isin?(wi?pos)sin?(wi?(pos+k))+cos?(wi?pos)cos?(wi?(pos+k))=∑icos?(wi?(pos?(pos+k)))=∑icos?(wi?k)\begin{align} \left<{\rm PE}(pos),{\rm PE}(pos+k)\right> &= \sum_{i} \sin(w_i·pos)\sin(w_i·(pos+k)) + \cos(w_i·pos)\cos(w_i·(pos+k)) \notag\\ &= \sum_i \cos(w_i·(pos - (pos+k)))\notag \\ &= \sum_i \cos(w_i·k) \notag \end{align} ?PE(pos),PE(pos+k)??=i∑?sin(wi??pos)sin(wi??(pos+k))+cos(wi??pos)cos(wi??(pos+k))=i∑?cos(wi??(pos?(pos+k)))=i∑?cos(wi??k)?
可以看出,如果兩個(gè)位置距離越遠(yuǎn),那么對(duì)應(yīng)的編碼向量?jī)?nèi)積就越小。
可以證明位置編碼和詞嵌入相加而不是拼接的合理性,證明見(jiàn) REF[5] 李宏毅視頻 30:00 處。
2.6. 掩碼機(jī)制
Mask 表示掩碼,它對(duì)某些值進(jìn)行掩蓋,使其在參數(shù)更新時(shí)不產(chǎn)生效果。Transformer 模型里面涉及兩種 mask,分別是 Padding Mask 和 Sequence Mask。Padding Mask 在所有的注意力(Attention)模塊里面都需要用到,而 Sequence Mask 只在解碼器的自注意力(Self-Attention)模塊里面用到。
Padding Mask
為了保證同一 batch 內(nèi)全部輸入序列(輸入到編碼器和解碼器的序列)的長(zhǎng)度一致,我們會(huì)設(shè)定一個(gè)固定長(zhǎng)度,過(guò)長(zhǎng)的序列會(huì)被截?cái)?#xff0c;多余部分被丟棄;在長(zhǎng)度不足的輸入序列后面填充 padding。Padding Mask 是一個(gè)二值掩碼,用于區(qū)分序列中每個(gè)位置是否為 padding 區(qū)域。因?yàn)樘畛?padding 的位置沒(méi)有任何意義,只是為了方便處理,所以注意力機(jī)制不應(yīng)該把注意力放在這些位置。具體來(lái)說(shuō),在計(jì)算注意力時(shí),進(jìn)行 Softmax 操作之前需要根據(jù) Padding Mask 將 padding 區(qū)域的值設(shè)置為負(fù)無(wú)窮,這樣經(jīng)過(guò)歸一化后,padding 區(qū)域的概率會(huì)接近 000,即目標(biāo)單詞不關(guān)注 padding 區(qū)域內(nèi)的單詞。
盡管引入 padding 后可以保證輸入序列長(zhǎng)度一致,但是使用 Batch Normalization 依舊沒(méi)有意義,這是因?yàn)榧尤氲?padding 本身沒(méi)有意義,只是起到占位符的作用,不改變序列長(zhǎng)度不相等的本質(zhì)。
在代碼實(shí)現(xiàn)時(shí),每個(gè)序列對(duì)應(yīng)自己的 Padding Mask,Padding Mask 是維度為序列長(zhǎng)度 seq_len\rm seq\_lenseq_len 的二值向量。Padding Mask 對(duì)大小為 seq_len×seq_len\rm seq\_len\times seq\_lenseq_len×seq_len 的矩陣 QKT/dkQK^T/\sqrt{d_k}QKT/dk?? 按行遮蓋,設(shè)置被遮蓋區(qū)域的值為負(fù)無(wú)窮,如圖 13(左)13\;(左)13(左) 所示。
注意到,并沒(méi)有遮蓋 <pad> 對(duì)應(yīng)的行,這是因?yàn)闅w一化操作以行為單位進(jìn)行的,<pad> 對(duì)應(yīng)的行表示 <pad> 對(duì)序列中單詞的關(guān)注度,顯然,我們不會(huì)在意這些關(guān)注度,所以 <pad> 對(duì)應(yīng)的行是否計(jì)算都無(wú)所謂。
Sequence Mask
Sequence Mask 只出現(xiàn)在解碼器中的自注意力模塊中,這主要是考慮到自注意力機(jī)制會(huì)提取全局信息,也就是說(shuō)在預(yù)測(cè)位置 pospospos 的單詞時(shí)用到序列中的全部單詞信息,包括 pospospos 之前(歷史)和之后(未來(lái))的單詞,但是,顯然在預(yù)測(cè)時(shí)用到未來(lái)信息是一種作弊行為,屬于信息泄露。在實(shí)際推斷中,不可能提前知道未來(lái)信息。因此,引入 Sequence Mask 對(duì)未來(lái)信息進(jìn)行遮蓋。具體來(lái)說(shuō),作用時(shí)機(jī)與 Padding Mask 一致,二者共同影響矩陣 QKT/dkQK^T/\sqrt{d_k}QKT/dk??,被遮蓋的區(qū)域的值設(shè)置為負(fù)無(wú)窮。
在代碼實(shí)現(xiàn)時(shí),一個(gè) batch 內(nèi)序列的 Sequence Mask 是完全一樣的,Sequence Mask 是大小為 seq_len×seq_len\rm seq\_len \times seq\_lenseq_len×seq_len 的二值上三角(不含對(duì)角線)矩陣,如圖 13(中)13\;(中)13(中) 所示。一般會(huì)先將 Padding Mask 和 Sequence Mask 合并成一個(gè)掩碼矩陣,再對(duì) QKT/dkQK^T/\sqrt{d_k}QKT/dk?? 處理,如圖 13(右)13\;(右)13(右) 所示。
圖 13????Padding Mask 操作過(guò)程(左)、Sequence Mask 操作過(guò)程(中) 和合并 Mask(右)
2.7. 線性映射和歸一化
這里講解的線性映射和歸一化是圖 444 最后一層解碼器連接的 Linear 層和 Softmax 層。這一步的目的在于將解碼器輸出的每個(gè)單詞向量映射到 vocab_size\rm vocab\_sizevocab_size 維,對(duì) vocab_size\rm vocab\_sizevocab_size 維歸一化,最高概率值對(duì)應(yīng)該單詞的預(yù)測(cè)標(biāo)簽。其中 vocab_size\rm vocab\_sizevocab_size 為模型的輸出詞匯表大小,即模型的輸出單詞肯定包含在詞匯表中。如圖 141414 所示。
圖 14????線性映射與歸一化
3. 訓(xùn)練與推斷
以漢(機(jī)器 學(xué)習(xí))譯英(machine learning)的機(jī)器翻譯任務(wù)為例講解訓(xùn)練階段和推斷階段的整體流程。
訓(xùn)練階段。batch 內(nèi)的每個(gè)樣本對(duì)應(yīng)三個(gè)序列:源序列、目標(biāo)序列和預(yù)測(cè)序列。源序列是輸入到編碼器中的序列,即 機(jī)器 學(xué)習(xí);目標(biāo)序列是輸入到解碼器中的序列,需要在序列前加上起始標(biāo)志 <BOS>,即 <BOS> machine learning;預(yù)測(cè)序列是模型對(duì)于輸入為源序列和目標(biāo)序列的預(yù)測(cè)結(jié)果,也可以認(rèn)為是最后一層解碼器的輸出經(jīng)過(guò)線性映射和歸一化的結(jié)果,預(yù)測(cè)序列可能與目標(biāo)序列不一致,比如預(yù)測(cè)出 machine translation <EOS> 或 love learning <EOS>,但是預(yù)測(cè)出的序列一定是以結(jié)束標(biāo)志 EOS 結(jié)尾。計(jì)算 batch 內(nèi)全部樣本的目標(biāo)序列中單詞的獨(dú)熱編碼與預(yù)測(cè)序列中對(duì)應(yīng)單詞的概率分布的交叉熵,定義損失函數(shù)為交叉熵之和,進(jìn)而反向傳播更新參數(shù)。圖 151515 展示了單樣本訓(xùn)練過(guò)程。
圖 15????單樣本訓(xùn)練過(guò)程
推斷階段。在推斷階段,目標(biāo)序列是未知的,所以不能像訓(xùn)練階段一樣并行輸入計(jì)算。模型根據(jù)當(dāng)前的目標(biāo)序列預(yù)測(cè)出下一個(gè)單詞,將預(yù)測(cè)出的單詞拼接到目標(biāo)序列上作為新的目標(biāo)序列輸入到模型中,繼續(xù)預(yù)測(cè)下一個(gè)單詞,直至預(yù)測(cè)到結(jié)束標(biāo)志 <EOS>,初始目標(biāo)序列僅由起始標(biāo)志 <BOS> 構(gòu)成。圖 444 中“shifted right”表達(dá)的正是這種串行的推斷方式。圖 161616 展示了單樣本推斷過(guò)程。
圖 16????單樣本推斷過(guò)程
注意兩點(diǎn):
REF
[1] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30.
[2] The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time. (jalammar.github.io)
[3] 68 Transformer【動(dòng)手學(xué)深度學(xué)習(xí)v2】- bilibili
[4] Transformer論文逐段精讀【論文精讀】- bilibili
[5] 李宏毅Transformer - bilibili
[6] 深入理解transformer源碼_- CSDN
[7] 自然語(yǔ)言處理(十五):Transformer介紹_自然語(yǔ)言處理transformer - CSDN
[8] 《The Annotated Transformer》翻譯——注釋和代碼實(shí)現(xiàn)《Attention Is All You Need》- CSDN
[9] 【自然語(yǔ)言處理】Attention 講解 - CSDN
[10] 超詳細(xì)圖解Self-Attention - 知乎
[11] 深度學(xué)習(xí)中的batch normalization - 知乎
[12] 層標(biāo)準(zhǔn)化詳解(Layer Normalization)- CSDN
[13] 【深度學(xué)習(xí)】Layer Normalization - CSDN
[14] 標(biāo)準(zhǔn)化、歸一化、規(guī)范化區(qū)別_- CSDN
[15] 什么?是Transformer位置編碼 - CSDN
[16] 位置編碼詳細(xì)解讀 - bilibili
[17] 碎碎念:Transformer的細(xì)枝末節(jié) - 知乎
[18] Properties of Dot Product of Random Vectors - 知乎
[19] Transformer學(xué)習(xí)筆記一:Positional Encoding(位置編碼) - 知乎
[20] Transformer 模型詳解_- CSDN
[21] 從訓(xùn)練和預(yù)測(cè)角度來(lái)理解Transformer中Masked Self-Attention的原理 - 知乎
[22] Transformer詳解 - 知乎
[23] transformer模型學(xué)習(xí)路線_transformer怎么訓(xùn)練 - CSDN
[24] ml6 - 博客園
總結(jié)
以上是生活随笔為你收集整理的【自然语言处理】Transformer 讲解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 基于python的Nginx日志管理分析
- 下一篇: mysql配置kodi16.1_kodi