Hierarchical Attention Networks for Document Classification (HAN)
Historical significance of HAN:
    1. Attention-based text classification models went on to receive a great deal of attention
    2. Handling long documents hierarchically gradually became a popular approach
    3. It pushed the use of attention mechanisms beyond Seq2Seq models

Problems largely overlooked by previous work:
    1. Different sentences in a document matter differently for classification
    2. Different words within a sentence also matter differently for classification

Main structure of the paper:
1. Abstract
    (The usual template: importance of the task -> shortcomings of previous work -> the proposed model -> experiments)
    The paper proposes a hierarchical attention network for document classification. It both mirrors the hierarchical structure of documents and applies attention at two levels, words and sentences, to pick out the important information.
2. Introduction
    In short: previous deep learning models achieved fairly good results, but they did not take into account that different parts of a document matter differently for the task; the hierarchical attention network is proposed to address this.
    Background in more detail:
        1. Text classification is one of the fundamental tasks of natural language processing, and researchers have increasingly turned to deep-learning-based text classification models.
        2. Although these deep learning models perform very well, they ignore the structure of the document and do not account for the fact that different parts of a document influence the classification to different degrees.
        3. To address this, the paper proposes a hierarchical attention network that models the hierarchical structure of documents and uses two attention mechanisms to learn context-dependent importance.
        4. The difference from previous work is that context is used to decide how important a sentence or word is, rather than judging a single sentence or word in isolation.
3. Hierarchical Attention Networks
    The section first introduces the GRU network, the recurrent unit used in both encoders.

    (Figure: structure of the GRU network.)
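For reference, here is a sketch of the standard GRU update equations (the usual formulation, matching the one in the HAN paper up to notation):

    z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)
    r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)
    \tilde{h}_t = \tanh\big(W_h x_t + r_t \odot (U_h h_{t-1}) + b_h\big)
    h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

Here z_t is the update gate and r_t the reset gate. A bidirectional GRU runs this recurrence in both directions and concatenates the two hidden states, which is why the feature dimension in the code below is 2*gru_size.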
The attention mechanism extracts the information that is important for the task from a large amount of input, so it can pick out the important sentences of a document and the important words within each sentence. (Figure: the HAN model structure.)
The Hierarchical Attention Network (HAN) model consists of four parts: Word Encoder, Word Attention, Sentence Encoder, and Sentence Attention.
Word Encoder:
The word encoder looks up the embedding of each input word and feeds the embedding sequence into a bidirectional GRU to obtain the GRU outputs. A code fragment:
""" 定義結構 """ if is_pretrain:self.embedding = nn.Embedding.from_pretrained(weights, freeze=False) else:self.embedding = nn.Embedding(vocab_size, embedding_size)self.word_gru = nn.GRU(input_size=embedding_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)""" 具體實現片段 """x_embedding = self.embedding(x) word_outputs, word_hidden = self.word_gru(x_embedding)Word Attention
Word Attention:
Word attention applies an attention mechanism at the word level. Here the vector u plays the role of the attention query; unlike the usual setting, this query is itself a model parameter that is optimized during training. The keys are the outputs of the GRU; the attention weights are computed from the query and the keys and then used to take a weighted sum. A code fragment:
""" 定義結構 """self.word_context = nn.Parameter(torch.Tensor(2*gru_size, 1),requires_grad=True) # 論文中提到的u也就是query,因為需要更新迭代所以這里面寫的是nn.Parameterself.word_dense = nn.Linear(2*gru_size,2*gru_size) # 定義一個全連接網絡""" 具體實現 """""" 對應論文中的公式 """ attention_word_outputs = torch.tanh(self.word_dense(word_outputs)) weights = torch.matmul(attention_word_outputs,self.word_context) weights = F.softmax(weights,dim=1)""" 有一些部分為0,所以權重矩陣對應位置有參數也沒有意義,所以做mask""" x = x.unsqueeze(2) if gpu:weights = torch.where(x!=0,weights,torch.full_like(x,0,dtype=torch.float).cuda()) else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights,dim=1).unsqueeze(1)+1e-4)Sentence?Encoder
Sentence Encoder:
The sentence vector is built first: it is the weighted sum of the word-level GRU outputs, with the weights produced by word attention. The sentence vectors are then fed into another bidirectional GRU to obtain sentence-level outputs. A code fragment:
""" 定義結構 """self.sentence_gru = nn.GRU(input_size=2*gru_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)""" 具體片段 """sentence_vector = torch.sum(word_outputs*weights,dim=1).view([-1,sentence_num,word_outputs.shape[-1]]) sentence_outputs, sentence_hidden = self.sentence_gru(sentence_vector)Sentence Attention
Sentence Attention:
Sentence-level attention is defined in the same way: it produces a weight for each sentence, and the weighted sum of the sentence-level GRU outputs gives the document representation. A code fragment:
""" 定義網絡結構 """self.sentence_context = nn.Parameter(torch.Tensor(2*gru_size, 1),requires_grad=True) self.sentence_dense = nn.Linear(2*gru_size,2*gru_size)self.fc = nn.Linear(2*gru_size,class_num) # 最后文檔表示做全連接分類""" 具體片段 """attention_sentence_outputs = torch.tanh(self.sentence_dense(sentence_outputs)) weights = torch.matmul(attention_sentence_outputs,self.sentence_context) weights = F.softmax(weights,dim=1) x = x.view(-1, sentence_num, x.shape[1]) x = torch.sum(x, dim=2).unsqueeze(2) if gpu:weights = torch.where(x!=0,weights,torch.full_like(x,0,dtype=torch.float).cuda()) else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))"""權重歸一化""" weights = weights / (torch.sum(weights,dim=1).unsqueeze(1)+1e-4)""" 獲得文檔向量表示 """ document_vector = torch.sum(sentence_outputs*weights,dim=1)""" 最后根據文檔表示對文檔進行分類""" output = self.fc(document_vector)?
4. Experiment
    The experiments compare the model with other methods on the same datasets and show that it performs better. The paper also inspects the distribution of attention weights for words such as "good" and "bad", showing that the same word is weighted differently in different contexts.
5. Related Work
    The related work section reviews the approaches taken by other authors and how they were implemented, serving both as a comparison and as background for this paper.
6. Conclusion

Key points:
    1. Previous deep-learning text classification models did not recognize that different parts of a document carry information of different importance
    2. An attention mechanism can learn how important each part of a document is for classification
    3. The HAN model is proposed

Innovations:
    1. A new text classification model, the hierarchical attention network (HAN)
    2. Two levels of attention that simultaneously learn the important sentences and the important words of a document
    3. State-of-the-art results on several text classification datasets

Takeaways:
    1. The intuition behind the model is that different parts of a document matter differently for classification, and how important a part is also depends on the words inside it, rather than being decided for that part in isolation
    2. The importance of words and sentences is context dependent: the same word or sentence can have a different importance in different contexts
7. Code Implementation
    Download link for the public IMDB dataset: http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp-2015-data.7z
""" 數據預處理部分 """from torch.utils import data import os import nltk import numpy as np import pickle from collections import Counter""" 數據集加載 """# 數據集加載 datas = open("./data/imdb/imdb-test.txt.ss",encoding="utf-8").read().splitlines() datas = [data.split(" ")[-1].split()+[data.split(" ")[2]] for data in datas] print(datas[0:1])[['i','knew','that','the','old-time','movie','makers','often','``','borrowed',"''",'or','outright','plagiarized','from','each','other',',','but','this','is','ridiculous','!','<sssss>','not','only','did','george','albert','smith','make','this','film','in','1899',',','but','and','company','made','a','nearly','identical','film','that','same','year','with','the','same','title','!!!','<sssss>','the','worst','part','about','it','is','that','neither','film','was','all','that','great','.','<sssss>','and',',','of','the','two',',','the','smith','one','is','slightly','less','well','made','.','<sssss>','like','all','movies','of','the','1890s',',','this','one','is','incredibly','brief','and','almost','completely','uninteresting','to','audiences','in','the','21st','century','.','<sssss>','only','film','historians','and','crazy','people','like','me','would','watch','this','brief','film','-lrb-','i',"'m",'a','history','teacher','and','film','lover','--','that',"'s",'my','excuse','for','watching','them','both','-rrb-','.','5']]# 根據長度排序,保證訓練時每個batch的長度一致datas = sorted(datas,key = lambda x:len(x),reverse=True) labels = [int(data[-1])-1 for data in datas] datas = [data[0:-1] for data in datas]print(labels[0:5]) print (datas[-5:])[7, 9, 9, 8, 9] [['one', 'of', 'the', 'best', 'movie', 'musicals', 'ever', 'made', '.', '<sssss>', 'the', 'singing', 'and', 'dancing', 'are', 'excellent', '.'], ['john', 'goodman', 'is', 'excellent', 'in', 'this', 'entertaining', 'portrayal', 'of', 'babe', 'ruth', "'s", 'life', '.'], ['how', 'to', 'this', 'movie', ':', 'disjointed', 'silly', 'unfulfilling', 'story', 'waste', 'of', 'time'], ['simply', 'a', 'classic', '.', '<sssss>', 'scenario', 'and', 'acting', 'are', 'excellent', '.'], ['there', 'were', 'tng', 'tv', 'episodes', 'with', 'a', 'better', 'story', '.']]# 構建 word2idmin_count = 5 word_freq = {} for data in datas:for word in data:word_freq[word] = word_freq.get(word,0)+1word2id = {"<pad>":0,"<unk>":1} for word in word_freq:if word_freq[word]<min_count:continueelse:word2id[word] = len(word2id)print(word2id){'<pad>': 0,'<unk>': 1,'i': 2,'only': 3,'just': 4,'got': 5,'around': 6,'to': 7,'watching': 8,'the': 9,'movie': 10,'today': 11,'.': 12,'<sssss>': 13,'when': 14,'it': 15,'came': 16,'out': 17,'in': 18,'movies': 19,',': 20,'heard': 21,'so': 22,'many': 23,'bad': 24,'things': 25,'about': 26,'...': 27,'how': 28,'fake': 29,'looked': 30,'long': 31,'winded': 32,'and': 33,'boring': 34,'was': 35,'stupid': 36,"n't": 37,'all': 38,'that': 39,'great': 40,'etc.': 41,'list': 42,'goes': 43,.........# 分句 for i,data in enumerate(datas):datas[i] = " ".join(data).split("<sssss>")for j,sentence in enumerate(datas[i]):datas[i][j] = sentence.split()# 將數據轉化為id max_sentence_length = 100 # 句子必須一樣的長度 batch_size = 64 # 每個batch size,每個文檔的句子一樣多 for i,document in enumerate(datas):for j,sentence in enumerate(document):for k,word in enumerate(sentence):datas[i][j][k] = word2id.get(word,word2id["<unk>"])datas[i][j] = datas[i][j][0:max_sentence_length] + \[word2id["<pad>"]]*(max_sentence_length-len(datas[i][j])) for i in range(0,len(datas),batch_size):max_data_length = max([len(x) for x in datas[i:i+batch_size]])for j in range(i,min(i+batch_size,len(datas))):datas[j] = datas[j] + 
[[word2id["<pad>"]]*max_sentence_length]*(max_data_length-len(datas[j]))"""得到最終輸入模型的數據-datas"""?
""" 模型構建部分 """# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:46 # @Author : TaoWang # @Description :from torch.nn import functional as F import torch.nn as nn import numpy as np import torchclass HAN_Model(nn.Module):def __init__(self,vocab_size,embedding_size,gru_size,class_num,is_pretrain=False,weights=None):""":param vocab_size::param embedding_size::param gru_size::param class_num::param is_pretrain::param weights:"""super(HAN_Model, self).__init__()if is_pretrain:self.embedding = nn.Embedding.from_pretrained(weights, freeze=False)else:self.embedding = nn.Embedding(vocab_size, embedding_size)self.word_gru = nn.GRU(input_size=embedding_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)self.word_context = nn.Parameter(torch.Tensor(2*gru_size, 1), requires_grad=True)self.word_dense = nn.Linear(2*gru_size, 2*gru_size)self.sentence_gru = nn.GRU(input_size=2*gru_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)self.sentence_context = nn.Parameter(torch.Tensor(2*gru_size, 1), requires_grad=True)self.sentence_dense = nn.Linear(2*gru_size, 2*gru_size)self.fc = nn.Linear(2*gru_size, class_num)def forward(self, x, gpu=False):""":param x::param gpu::return:"""sentence_num = x.shape[1]sentence_length = x.shape[2]x = x.view([-1, sentence_length])x_embedding = self.embedding(x)word_outputs, word_hidden = self.word_gru(x_embedding)attention_word_outputs = torch.tanh(self.word_dense(word_outputs))weights = torch.matmul(attention_word_outputs, self.word_context)weights = F.softmax(weights, dim=1)x = x.unsqueeze(2)if gpu:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float).cuda())else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights, dim=1).unsqueeze(1) + 1e-4)sentence_vector = torch.sum(word_outputs * weights, dim=1).view([-1, sentence_num, word_outputs.shape[-1]])sentence_outputs, sentence_hidden = self.sentence_gru(sentence_vector)attention_sentence_outputs = torch.tanh(self.sentence_dense(sentence_outputs))weights = torch.matmul(attention_sentence_outputs, self.sentence_context)weights = F.softmax(weights, dim=1)x = x.view(-1, sentence_num, x.shape[1])x = torch.sum(x, dim=2).unsqueeze(2)if gpu:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights, dim=1).unsqueeze(1) + 1e-4)document_vector = torch.sum(sentence_outputs * weights, dim=1)output = self.fc(document_vector)return outputif __name__ == "__main__":han_model = HAN_Model(vocab_size=30000, embedding_size=200, gru_size=50, class_num=4)x = torch.Tensor(np.zeros([64, 50, 100])).long()x[0][0][0:10] = 1output = han_model(x)print(output)?
""" 模型訓練部分 """# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:43 # @Author : TaoWang # @Description : 模型訓練過程import torch import torch.autograd as autograd import torch.nn as nn import torch.optim as optim from model import HAN_Model from data import IMDB_Data import numpy as np from tqdm import tqdm import config as argumentparserconfig = argumentparser.ArgumentParser() torch.manual_seed(config.seed)if config.cuda and torch.cuda.is_available(): # 是否使用gputorch.cuda.set_device(config.gpu)# 導入訓練集 training_set = IMDB_Data("imdb-train.txt.ss",min_count=config.min_count,max_sentence_length = config.max_sentence_length,batch_size=config.batch_size,is_pretrain=False) training_iter = torch.utils.data.DataLoader(dataset=training_set,batch_size=config.batch_size,shuffle=False,num_workers=0)# 導入測試集 test_set = IMDB_Data("imdb-test.txt.ss",min_count=config.min_count,word2id=training_set.word2id,max_sentence_length = config.max_sentence_length,batch_size=config.batch_size) test_iter = torch.utils.data.DataLoader(dataset=test_set,batch_size=config.batch_size,shuffle=False,num_workers=0)model = HAN_Model(vocab_size=len(training_set.word2id),embedding_size=config.embedding_size,gru_size = config.gru_size,class_num=config.class_num,weights=training_set.weight,is_pretrain=False)if config.cuda and torch.cuda.is_available(): # 如果使用gpu,將模型送進gpumodel.cuda()criterion = nn.CrossEntropyLoss() # 這里會做softmax optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) loss = -1def get_test_result(data_iter,data_set):# 生成測試結果model.eval()true_sample_num = 0for data, label in data_iter:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()if config.cuda and torch.cuda.is_available():out = model(data, gpu=True)else:out = model(data)true_sample_num += np.sum((torch.argmax(out, 1) == label).cpu().numpy())acc = true_sample_num / data_set.__len__()return accfor epoch in range(config.epoch):model.train()process_bar = tqdm(training_iter)for data, label in process_bar:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()label = torch.autograd.Variable(label).squeeze()if config.cuda and torch.cuda.is_available():out = model(data,gpu=True)else:out = model(data)loss_now = criterion(out, autograd.Variable(label.long()))if loss == -1:loss = loss_now.data.item()else:loss = 0.95*loss+0.05*loss_now.data.item()process_bar.set_postfix(loss=loss_now.data.item())process_bar.update()optimizer.zero_grad()loss_now.backward()optimizer.step()test_acc = get_test_result(test_iter, test_set)print("The test acc is: %.5f" % test_acc) """ 配置文件-相關配置參數"""# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:43 # @Author : TaoWang # @Description :import argparsedef ArgumentParser():parser = argparse.ArgumentParser()parser.add_argument('--embed_size', type=int, default=10, help="embedding size of word embedding")parser.add_argument("--epoch", type=int, default=200, help="epoch of training")parser.add_argument("--cuda", type=bool, default=True, help="whether use gpu")parser.add_argument("--gpu", type=int, default=2, help="gpu num")parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate during training")parser.add_argument("--batch_size", type=int, default=64, help="batch size during training")parser.add_argument("--seed", type=int, default=0, help="seed of random")parser.add_argument("--min_count", type=int, default=5, help="min count of 
words")parser.add_argument("--max_sentence_length", type=int, default=100, help="max sentence length")parser.add_argument("--embedding_size", type=int, default=200, help="word embedding size")parser.add_argument("--gru_size", type=int, default=50, help="gru size")parser.add_argument("--class_num", type=int, default=10, help="class num")return parser.parse_args()?