pytorch笔记:09)Attention机制
剛從圖像處理的hole中攀爬出來,剛走一步竟掉到了另一個hole(fire in the hole*▽*)
1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的輸入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相對于傳統的encoder-decoder模型,attention機制僅在decoder處有所不同。下面具體看看:
1>保存了rnn每個詞向量對應隱藏層的輸出狀態(encoder_outputs),用于decoder的attention機制
#train代碼部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.輸入的input經過embed
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.獲取關于輸入的attention權重,這里的Q=decoder_rnn的input,K=decoder_rnn的隱藏元
2.1求Q和K相似度的方法有很多,這里讓全連接層自己來學習,把embedded和hidden連接在一起經過fc層(部分修改了下)
similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 經過softmax獲得歸一化的權重
attn_weights = F.softmax(similarity, dim=1)
1
3.權重應用于encoder輸出的所有詞對應的詞向量上(對應相乘即可)->獲得attention結果
attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention結果和decoder的輸入cat在一起,使用1個全連接層來融合二者,最終生成帶注意力機制的詞向量
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根據decoder的上一個輸出單詞來預測下一個單詞,這里多插一句,decoder的首個輸入為起始標志符’sos’,其根據encode最后的隱藏元來預測第一個單詞,后面依次類推。
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸氣標題),pytorch代碼推薦2篇:
哈佛大學NLP研究組:http://nlp.seas.harvard.edu/2018/04/03/attention.html
臺灣小哥的代碼(較通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:
下面以soft_attention為例(*input和output的attention,僅和self_attention做下區分,第1篇代碼標記src_attn,第2篇代碼標記dec_enc_attn),soft_attention的目標:給定序列Q(query,長度記為lq,維度dk),鍵序列K(key,長度記為lk,維度dk),值序列V(value,長度記為lv,維度dv),計算Q和K的相似度權重,最后再乘上V。
下面直接貼上attention-is-all-you-need-pytorch中MultiHeadAttention代碼
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#這里把batch和分塊數放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,這里的batch_size和seq_len均不為1,其把序列視為一個整體,求Q和V的相似度可使用點乘(V可以視為上面提及的encoder_outputs),獲得的是一個相似度矩陣,比如Q是一個長度為10的序列,K是一個長度為16的序列,其相似度矩陣就是一個10*16的矩陣,再如矩陣第一行表示Q的第一個單詞和K序列所有單詞的相似度。
similarity:=(lq,dk)?(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)?(dk,lk)=(lq,lk)
然后,生成帶注意力機制的詞向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而這里使用的是殘差(類似于unet和resnet),最后使用PositionwiseFeedForward(2個fc層)來融合attn_applied和input,最終生成帶注意力機制的詞向量。
attention_applied=(lq,lk)?(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)?(lv,dv)=(lq,dv)
細節部分
在數據預處理部分,對序列s都進行了首尾標記,比如s=’’+ s + ‘’,剛看transform(之前跳過了seq2seq),對下面的代碼甚是不解
decoder_input=target_seq[:, :-1] #這里不是去掉終止標記<eos>,去掉的可能是padding_0,只為兼容target_ground_y的序列長度?
encoder_input=input_seq[:, 1:] #encoder的輸入序列去掉了起始標記<sos>
target_ground_y= target_seqtrg[:, 1:] #用于計算模型loss的target,去掉了起始標記<sos>
1
2
3
其實在pytorch官方教程中說的比較清楚,看下圖
encoder的輸入序列和ground_true只需要一個終止符即可,而decoder的輸入序列開始必須指定一個起始符,讓其根據context預測輸出序列的第一個單詞,后面根據前一個單詞再預測下一個單詞,依次類推直到當前預測的單詞為終止標記’eos’,才計算loss.
---------------------
作者:PJ-Javis
來源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!
轉載于:https://www.cnblogs.com/jfdwd/p/11068075.html
總結
以上是生活随笔為你收集整理的pytorch笔记:09)Attention机制的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 从输入URL到页面加载的过程
- 下一篇: [docker] 04 使用docker