Transformer结构详解(有图,有细节)
文章目錄
- 1. transformer的基本結(jié)構(gòu)
- 2. 模塊詳解
-
- 2.1 模塊1:Positional Embedding
- 2.2 模塊2:Multi-Head Attention
-
- 2.2.1 Scaled Dot-Product Attention
- 2.2.2 Multi-Head
- 2.3 模塊3:ADD
- 2.4 模塊4:Layer Normalization
- 2.5 模塊5:Feed Forward NetWork
- 2.6 模塊6:Masked Multi-Head Attention
- 2.7 模塊7: Multi-Head Attention
- 2.8 模塊8:Linear
- 2.9 模塊9:SoftMax
- 3. transformer在機(jī)器翻譯任務(wù)中的使用
- 4 transformer 相關(guān)的其它問題
?
1. transformer的基本結(jié)構(gòu)
2. 模塊詳解
2.1 模塊1:Positional Embedding
??P E PEPE模塊的主要做用是把位置信息加入到輸入向量中,使模型知道每個(gè)字的位置信息。對(duì)于每個(gè)位置的P E PEPE是固定的,不會(huì)因?yàn)檩斎氲木渥硬煌煌?#xff0c;且每個(gè)位置的P E PEPE大小為1 ? n 1 *n1?n(n為word embedding 的dim size),transformer中使用正余弦波來計(jì)算P E PEPE,具體如下:
P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}})PE(pos,2i)?=sin(pos/100002i/dmodel?)PE(pos,2i+1)?=cos(pos/100002i/dmodel?)
- p o s pospos代表的是一個(gè)字在句子中的位置,從0到名字長(zhǎng)度減1,是下圖中紅色的序號(hào)。
- i ii代表的是dim 的序號(hào),是下圖中藍(lán)色的序號(hào):
- 當(dāng)i ii為偶數(shù)時(shí),此位置的值使用?s i n ( p o s / 1000 0 2 i / d m o d e l ) sin(pos/10000^{2i/d_{model}})sin(pos/100002i/dmodel?)來填充。
- 當(dāng)i ii為奇數(shù)時(shí),些位置的值使用?c o s ( p o s / 1000 0 2 i / d m o d e l ) cos(pos/10000^{2i/d_{model}})cos(pos/100002i/dmodel?)來填充
實(shí)現(xiàn)代碼:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
至于為什么選擇這種方式,論文中給出的解釋是:
理解:
由s i n ( α + β ) = s i n α c o s β + s i n β c o s α c o s ( α + β ) = c o s α c o s β ? s i n β s i n α sin(\alpha+\beta)=sin\alpha cos\beta + sin\beta cos\alpha\\ cos(\alpha+\beta)=cos\alpha cos\beta - sin\beta sin\alphasin(α+β)=sinαcosβ+sinβcosαcos(α+β)=cosαcosβ?sinβsinα
推出:
P E ( p o s + k , 2 i ) = s i n ( ( p o s + k ) / 1000 0 2 i / d m o d e l ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) + s i n ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) = P E ( p o s , 2 i ) P E ( k , 2 i + 1 ) ? P E ( p o s , 2 i + 1 ) P E ( k , 2 i )
PE(pos+k,2i)=sin((pos+k)/100002i/dmodel)=sin(pos/100002i/dmodel)cos(k/100002i/dmodel)+sin(pos/100002i/dmodel)cos(k/100002i/dmodel)=PE(pos,2i)PE(k,2i+1)?PE(pos,2i+1)PE(k,2i)PE(pos+k,2i)=sin((pos+k)/100002i/dmodel)=sin(pos/100002i/dmodel)cos(k/100002i/dmodel)+sin(pos/100002i/dmodel)cos(k/100002i/dmodel)=PE(pos,2i)PE(k,2i+1)?PE(pos,2i+1)PE(k,2i)
PE(pos+k,2i)?=sin((pos+k)/100002i/dmodel?)=sin(pos/100002i/dmodel?)cos(k/100002i/dmodel?)+sin(pos/100002i/dmodel?)cos(k/100002i/dmodel?)=PE(pos,2i)PE(k,2i+1)?PE(pos,2i+1)PE(k,2i)?P E ( p o s + k , 2 i + 1 ) = c o s ( ( p o s + k ) / 1000 0 2 i / d m o d e l ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) ? s i n ( p o s / 1000 0 2 i / d m o d e l ) s i n ( k / 1000 0 2 i / d m o d e l ) = P E ( p o s , 2 i + 1 ) P E ( k , 2 i + 1 ) ? P E ( p o s , 2 i ) P E ( k , 2 i )
PE(pos+k,2i+1)=cos((pos+k)/100002i/dmodel)=cos(pos/100002i/dmodel)cos(k/100002i/dmodel)?sin(pos/100002i/dmodel)sin(k/100002i/dmodel)=PE(pos,2i+1)PE(k,2i+1)?PE(pos,2i)PE(k,2i)PE(pos+k,2i+1)=cos((pos+k)/100002i/dmodel)=cos(pos/100002i/dmodel)cos(k/100002i/dmodel)?sin(pos/100002i/dmodel)sin(k/100002i/dmodel)=PE(pos,2i+1)PE(k,2i+1)?PE(pos,2i)PE(k,2i)
PE(pos+k,2i+1)?=cos((pos+k)/100002i/dmodel?)=cos(pos/100002i/dmodel?)cos(k/100002i/dmodel?)?sin(pos/100002i/dmodel?)sin(k/100002i/dmodel?)=PE(pos,2i+1)PE(k,2i+1)?PE(pos,2i)PE(k,2i)?以P E ( p o s + k , 2 i ) = P E ( p o s , 2 i ) P E ( k , 2 i + 1 ) ? P E ( p o s , 2 i + 1 ) P E ( k , 2 i ) PE(pos+k,2i)=PE(pos,2i)PE(k,2i+1)-PE(pos,2i+1)PE(k,2i)PE(pos+k,2i)=PE(pos,2i)PE(k,2i+1)?PE(pos,2i+1)PE(k,2i)為例,當(dāng)k kk確定時(shí):?P E ( k , 2 i + 1 ) PE(k,2i+1)PE(k,2i+1)、P E ( k , 2 i ) PE(k,2i)PE(k,2i)均為常數(shù),P E ( p o s + k , 2 i ) = P E ( p o s , 2 i ) ? 常 數(shù) 2 i + 1 k ? P E ( p o s , 2 i + 1 ) ? 常 數(shù) i k PE(pos+k,2i)=PE(pos,2i) * 常數(shù)_{2i+1}^k - PE(pos,2i+1) * 常數(shù)_{i}^kPE(pos+k,2i)=PE(pos,2i)?常數(shù)2i+1k??PE(pos,2i+1)?常數(shù)ik?
上式為即為1)中所說的線性函數(shù)。我們知道,每個(gè)位置(pos)的PE值均不同,因此我們可以根據(jù)PE的值區(qū)分位置,而由上面的線性函數(shù),我們可以計(jì)量出兩個(gè)位置的相對(duì)距離。
理解:
??第二點(diǎn)很好理解就是說了下正弦波的優(yōu)點(diǎn)。這里我著重講下正弦波存在的問題。在transformer架構(gòu)里,我們計(jì)算兩個(gè)特征的關(guān)系用的是點(diǎn)積的的方式(因?yàn)槭褂昧薉ot-Product Attention)。所以兩個(gè)PE的關(guān)系(距離)實(shí)際是以它們的點(diǎn)積來表示的。舉例如下[ 1 ] ^{[1]}[1]:
我們令c i = 1 / 1000 0 2 i / d m o d e l c_i=1/10000^{2i/d_{model}}ci?=1/100002i/dmodel?,則第t tt及t + 1 t+1t+1個(gè)位置的positional embedding 是:
P E t = [ s i n ( c 0 t ) c o s ( c 0 t ) s i n ( c 1 t ) c o s ( c 1 t ) ? s i n ( c d 2 ? 1 t ) c o s ( c d 2 ? 1 t ) ] T PE_t={\left[ {
sin(c0t)cos(c0t)sin(c1t)cos(c1t)?sin(cd2?1t)cos(cd2?1t)sin(c0t)cos(c0t)sin(c1t)cos(c1t)?sin(cd2?1t)cos(cd2?1t)
} \right]^T}PEt?=????????????sin(c0?t)cos(c0?t)sin(c1?t)cos(c1?t)?sin(c2d??1?t)cos(c2d??1?t)?????????????TP E t + k = [ s i n ( c 0 ( t + k ) ) c o s ( c 0 ( t + k ) ) s i n ( c 1 ( t + k ) ) c o s ( c 1 ( t + k ) ) ? s i n ( c d 2 ? 1 ( t + k ) ) c o s ( c d 2 ? 1 ( t + k ) ) ] T PE_{t+k}={\left[ {
sin(c0(t+k))cos(c0(t+k))sin(c1(t+k))cos(c1(t+k))?sin(cd2?1(t+k))cos(cd2?1(t+k))sin(c0(t+k))cos(c0(t+k))sin(c1(t+k))cos(c1(t+k))?sin(cd2?1(t+k))cos(cd2?1(t+k))
} \right]^T}PEt+k?=????????????sin(c0?(t+k))cos(c0?(t+k))sin(c1?(t+k))cos(c1?(t+k))?sin(c2d??1?(t+k))cos(c2d??1?(t+k))?????????????T則:P E t P E t + k = Σ j = 0 d 2 [ s i n ( c j t ) s i n ( c j ( t + k ) + c o s ( c j t ) c o s ( c j ( t + k ) ] = Σ j = 0 d 2 c o s ( c j ( t ? ( t + k ) ) = Σ j = 0 d 2 c o s ( c j k )
PEtPEt+k=Σd2j=0[sin(cjt)sin(cj(t+k)+cos(cjt)cos(cj(t+k)]=Σd2j=0cos(cj(t?(t+k))=Σd2j=0cos(cjk)PEtPEt+k=Σj=0d2[sin(cjt)sin(cj(t+k)+cos(cjt)cos(cj(t+k)]=Σj=0d2cos(cj(t?(t+k))=Σj=0d2cos(cjk)
PEt?PEt+k??=Σj=02d??[sin(cj?t)sin(cj?(t+k)+cos(cj?t)cos(cj?(t+k)]=Σj=02d??cos(cj?(t?(t+k))=Σj=02d??cos(cj?k)???上式的第二行是使用了?c o s ( α ? β ) = s i n α s i n β + c o s α c o s β cos(\alpha-\beta)=sin\alpha sin\beta + cos\alpha cos\betacos(α?β)=sinαsinβ+cosαcosβ?這個(gè)公式進(jìn)行的變換。從最終的結(jié)果我們可以看出,兩個(gè)embedding的距離度量只與間隔k kk有關(guān),而c o s coscos函數(shù)關(guān)于y軸對(duì)稱,即c o s x = c o s ( ? x ) cosx=cos(-x)cosx=cos(?x),所以,P E t P E t + k PE_tPE_{t+k}PEt?PEt+k?的度量只與k kk的大小有關(guān),與誰在前,誰在后無關(guān)。即,經(jīng)過dot-attention機(jī)制后,我們把positional embedding中的順序信息丟失了。所以,從這方面看,正弦波這種位置PE并不太適合在用在transformer結(jié)構(gòu)中,這也可能是后面的bert,t5都采用的基于學(xué)習(xí)的positional embedding。(注:模塊3會(huì)把順序信息傳遞下去,但我們還是在算法的核心處理上丟失了信息。)
2.2 模塊2:Multi-Head Attention
??這個(gè)模塊是transformer的核心,我們把這塊拆成兩部分來理解,先講下其中的Scaled Dot-Product Attention(縮放的點(diǎn)積注意力機(jī)制),再講Multi-Head。
2.2.1 Scaled Dot-Product Attention
??我們先看下論文中的 Scaled Dot-Product Attention 步驟,如下圖:
下面我們對(duì)著上面的圖講一下,具體的看下每步做了什么。
由于linear的輸入和輸出均為d m o d e l d_{model}dmodel?,所以Q,K,V的大小和input_sum的大小是一致的。
MatMul: 這步是實(shí)際是計(jì)算的?Q ? K T Q*K^TQ?KT, 如下圖:
從上圖可以看出Q ? K T Q*K^TQ?KT的結(jié)果s c o r e s scoresscores是一個(gè)L ? L L*LL?L的矩陣(L為句字長(zhǎng)度),其中scores中的[ i , j ] [i,j][i,j]位置表示的是Q QQ中的第i ii行的字和K T K^TKT中第j jj列的相似度(也可以說是重要度,我們可以這么理解,在機(jī)器翻譯任務(wù)中,當(dāng)我們翻譯一句話的第i ii個(gè)字的的時(shí)候,我們要考慮原文中哪個(gè)位置的字對(duì)我們現(xiàn)在要翻譯的這個(gè)位置的字的影響最大)。
Scale :這部分就是對(duì)上面的s c o r e s scoresscores進(jìn)行了個(gè)類似正則化的操作。
s c o r e s = s c o r e s d q scores=\frac{scores}{\sqrt{d_q}}scores=dq??scores??(這里要說一下d q d_{q}dq?,論文中給出的是d h d_{h}dh?,即d m o d e l / h d_{model}/hdmodel?/h, 因?yàn)檎撐闹凶隽薽ulti-head,所以?d q = d h d_q=d_{h}dq?=dh?),這里解釋下除以d q \sqrt{d_q}dq??的原因,原文是這樣說的:“我們認(rèn)為對(duì)于大的d k d_kdk?,點(diǎn)積在數(shù)量級(jí)上增長(zhǎng)的幅度大,將softmax函數(shù)推向具有極小梯度的區(qū)域4 ^44。為了抵消這種影響,我們對(duì)點(diǎn)積擴(kuò)展1 d k \frac{1}{\sqrt{d_k}}dk??1?倍”。
Mask: 這步使用一個(gè)很小的值,對(duì)指定位置進(jìn)行覆蓋填充。這樣,在之后計(jì)算softmax時(shí),由于我們填充的值很小,所以計(jì)算出的概率也會(huì)很小,基本就忽略了。(如果不填個(gè)很小的值的話,后面我們計(jì)算softmax時(shí),e x i ∑ i = 1 k e x i \frac{e^{x_i}}{\sum_{i=1}^{k}{e^{x_i}}}∑i=1k?exi?exi???,當(dāng)x = 0 x=0x=0時(shí)(padding的值),分子e 0 = 1 e^{0}=1e0=1這可不是一個(gè)很小的值。),mask操作在encoder和decoder過程中都存在,在encoder中我們是對(duì)padding的值進(jìn)行mask,在decoder中我們主要是為了不讓前面的詞在翻譯時(shí)看到未來的詞,所以對(duì)當(dāng)前詞之后的詞的信息進(jìn)行mask。下面我們先看看encoder中關(guān)于padding的mask是怎么做的。
??如上圖,輸入中有兩個(gè)pad字符,s c o r e s scoresscores中的x都是pad參與計(jì)算產(chǎn)生的,我們?yōu)榱伺懦齪ad產(chǎn)生的影響,我們提供了如圖的mask,我們把scores與mask的位置一一對(duì)應(yīng),如果mask的值為0,則scores的對(duì)應(yīng)位置填充一個(gè)非常小的負(fù)數(shù)(例如:? e 9 -e^9?e9)。最終得到的是上圖最后一個(gè)表格。說了這么多,其實(shí)在pytorch中就一句話。
- 1
附上代碼:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
2.2.2 Multi-Head
??這里我們看看multi-head attention中的 multi-head是什么意思。我們假設(shè)d m o d e l = 512 d_{model}=512dmodel?=512,h = 8 h=8h=8(8個(gè)頭),說下transformer中是怎么處理的:
??前面我們說過了,Q QQ、K KK、V VV三個(gè)矩陣是encoder的輸入經(jīng)過三個(gè)linear映射而成,它們的大小是[ B , L , D ] [B,L,D][B,L,D](batch size, max sentence length, embedding size), 這里為了說的清楚些,我們暫時(shí)不看[ B ] [B][B]這個(gè)維度。那么Q QQ、K KK、V VV的維度都為[ L , D ] [L,D][L,D],multi-head就是在[ D ] [D][D]維度上對(duì)數(shù)據(jù)進(jìn)行切割,把數(shù)據(jù)切成等長(zhǎng)的8段(h = 8 h=8h=8),這樣Q QQ、K KK、V VV均被切成等長(zhǎng)的8段,然后對(duì)應(yīng)的Q QQ、K KK、V VV子段組成一組,每組通過 Scaled Dot-Product Attention 算法 計(jì)算出結(jié)果,這樣的結(jié)果我們會(huì)得到8個(gè),然后把這8個(gè)結(jié)果再拼成一個(gè)結(jié)果,就multi-head的結(jié)果。具體過程如下圖:
2.3 模塊3:ADD
??此模塊做了個(gè)類似殘差的操作,不同的是不是用輸入減去輸出,而是用輸入加上輸出。(指Multi-Head Attention模塊的輸入和輸出),具體操作就是把模塊2的輸入矩陣與模塊2的輸入矩陣的對(duì)應(yīng)位置做加法運(yùn)算。
2.4 模塊4:Layer Normalization
??不論是layer normalization還是batch normalization,其實(shí)做的都是一件事情,都是根據(jù)?x = a ? x ? x  ̄ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + bx=a?std+epsx?x?+b對(duì)x xx的分布進(jìn)行調(diào)整。不同的是x  ̄ \overline{x}x和s t d stdstd的計(jì)算方式不同。如下圖:
??batch normalization的x  ̄ \overline{x}x和s t d stdstd是延粉色方向計(jì)算的,而layer normalization是延藍(lán)色方向計(jì)算的。如果兄弟們?nèi)ッ嬖?#xff0c;可能面試官會(huì)問為什么這里沒有使用BN,而使用了LN,我的理解是,BN對(duì)batch size的大小是有要求的,一般batch size越大,計(jì)算出的x  ̄ \overline{x}x越好,而我用12G內(nèi)存的GPU,跑transformer的模型時(shí),batch size最多也就設(shè)置到32。batch size還是偏小的。所以使用與batch size無關(guān)的layer normlization。從另一個(gè)角度講,batch size之所以小,是因?yàn)槲覀兊膃mbedding size 大,而layer normalization 正好是延這個(gè)方向做的,所以正好使得layer normalization計(jì)算的更穩(wěn)定。
2.5 模塊5:Feed Forward NetWork
??Feed Forward NetWork 翻譯成中文叫 前饋網(wǎng)絡(luò),其實(shí)就是MLP。我們這里不糾結(jié)于FFN的定義,我們直接看下transformer里是怎么實(shí)現(xiàn)的。如下圖,我們先把輸入向量從512維(d m o d e l d_{model}dmodel?)映射到2048維,然后再映射到512維。實(shí)現(xiàn)時(shí),就是使用兩個(gè)linear層,第一個(gè)linear的輸入是512維,輸出是2048維,第二個(gè)linear的輸入是2048,輸出是512。
2.6 模塊6:Masked Multi-Head Attention
??上文已講了Multi-Head Attention,而且在講 Scaled Dot-Product Attention 時(shí)也講了mask機(jī)制,此模塊的區(qū)別在于maked的策略不同,在encoder中我們是把padding給masked掉,這里我們除了要考慮padding,還要考慮預(yù)測(cè)時(shí)的未來變量問題,換句話說,我們是用一句話中的前N ? 1 N-1N?1個(gè)字預(yù)測(cè)第N NN個(gè)字,那么我們?cè)陬A(yù)測(cè)第N NN個(gè)字時(shí),就不能讓模型看到第N個(gè)字之后的信息,所以這里們把預(yù)測(cè)第N NN個(gè)字時(shí),第N NN(包括)個(gè)字之后的字都masked掉。我們假設(shè)預(yù)測(cè)序列為’i like this apple’,則我們要做如下的mask。
2.7 模塊7: Multi-Head Attention
??模塊7 與上文 模塊2(encoder 中 的 Multi-Head Attention) 代碼實(shí)現(xiàn)上完全相同,區(qū)別再于模塊2 只有一個(gè)輸入,模塊2把此輸入經(jīng)過三個(gè)linear映射成Q QQ、K KK、V VV?, 而模塊7的輸入有兩個(gè),一個(gè)是decoder的輸入經(jīng)過第一個(gè)大模塊傳過來的值(為了方便,我們叫它input_x),一個(gè)是encoder最終結(jié)果(我們暫叫它input_memory), 模塊7是把input_x通過一個(gè)linear映射成了Q QQ,然后通過兩個(gè)linear把input_memory映射成K KK、V VV?,其它的與模塊2完全一致。
2.8 模塊8:Linear
??此模塊的目的是把模型的輸transformer decoder的輸出從d m o d e l d_{model}dmodel?維度映射到詞表大小的維度。linear本身也比較簡(jiǎn)單,這里不再細(xì)講了。
2.9 模塊9:SoftMax
??此模塊會(huì)把上層linear的輸出轉(zhuǎn)化成概率,對(duì)應(yīng)到某個(gè)字的概率。
3. transformer在機(jī)器翻譯任務(wù)中的使用
??在《Attention is All You Need》這篇文章中,是把transformer做為一個(gè)特征提取器放在一個(gè)Encoder-Decoder(下文用Encoder-Stack和Decoder-Stack,用以和transformer的encoder, decoder區(qū)分)架構(gòu)中的,具體細(xì)節(jié)見下圖:
??上面的圖片把整個(gè)結(jié)構(gòu)基本都畫出來了,這里再說下訓(xùn)練時(shí)的數(shù)據(jù)走向及流程:
1) 數(shù)據(jù)X XX?輸入到Encoder-Stack中,得到輸出變量e n c o d e r _ o u t p u t encoder\_outputencoder_output
2)?e n c o d e r _ o u t p u t encoder\_outputencoder_output?做為K e y KeyKey和V a l u e ValueValue的原始輸入 輸入到Decoder-Stack中,Decoder-Stack的Query為上一輪Decoder-Stack的輸出。
具體流程見下圖:
??這里我提一下decoder stack的輸入(上圖中的Query),前面說過了,在transformer中,decoder的核心思想是用一個(gè)句子中的前N ? 1 N-1N?1個(gè)字,預(yù)測(cè)第N NN個(gè)字,但在預(yù)測(cè)第一個(gè)字的時(shí)候,前面沒有字,這時(shí)我們可以在每句話前面加上一個(gè)固定的開始標(biāo)志(bos), 這樣相當(dāng)于把整個(gè)句子右移了一位。
4 transformer 相關(guān)的其它問題
??這部分我是想寫寫transformer的并行等其它問題,但今天寫的太累了,主要的也都寫完了,就先發(fā)了。
References
[ 1 ] [1][1]?https://zhuanlan.zhihu.com/p/166244505
總結(jié)
以上是生活随笔為你收集整理的Transformer结构详解(有图,有细节)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Transformer 详解
- 下一篇: Transformer 模型详解