自然语言推断:微调BERT
文章目錄
- 自然語言推斷:微調BERT
- 1 - 加載預訓練的BERT
- 2 - 微調BERT的數據集
- 3 - 微調BERT
- 4 - 小結
自然語言推斷:微調BERT
在前幾章中,我們已經為SNLI數據集上的自然語言推斷任務設計了一個基于注意力的結構。現在,我們通過微調BERT來重新審視這項任務,自然語言推斷任務時一個序列級別的文本對分類問題,而微調BERT只需要一個額外的基于多層感知機的架構
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-ArZ3KHum-1665475659087)(images/f1.png)]
在本節中,我們將下載一個預訓練好的小版本BERT,然后對其進行微調,以便在SNLI數據集上進行自然語言推斷
1 - 加載預訓練的BERT
我們以前在WikiText-2數據集上預訓練BERT(注意,原始的BERT模型是在更大的語料庫上預訓練的),原始的BERT模型有數以億計的參數。在下面,我們提供了兩個版本的預訓練的BERT:"bert.base"與原始的BERT基礎模型一樣大,需要大量的計算資源才能進行微調。而"bert.small"是一個小版本,便于演示
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip', '225d66f04cae318b841a13d32af3acc165f253ac') d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip', 'c72329e68a732bef0452e4b96a1c341c8910f81f')兩個預訓練好的BERT模型都包含一個定義詞表的“vocab.json”文件和一個預訓練參數的“pretrained.params”文件。我們實現了以下load_pretrained_model函數來加載預先訓練好的BERT參數
def load_pretrained_model(pretrained_model,num_hiddens,ffn_num_hiddens,num_heads,num_layers,dropout,max_len,devices):data_dir = d2l.download_extract(pretrained_model)# 定義空詞表以加載預定義詞表vocab = d2l.Vocab()vocab.idx_to_token = json.load(open(os.path.join(data_dir,'vocab.json')))vocab.token_to_idx = {token: idx for idx,token in enumerate(vocab.idx_to_token)}bert = d2l.BERTModel(len(vocab),num_hiddens,norm_shape=[256],ffn_num_input=256,ffn_num_hiddens=ffn_num_hiddens,num_heads=4,num_layers=2,dropout=0.2,max_len=max_len,key_size=256,query_size=256,value_size=256,hid_in_features=256,mlm_in_features=256,nsp_in_features=256)# 加載預訓練BERT參數bert.load_state_dict(torch.load(os.path.join(data_dir,'pretrained.params')))return bert,vocab為了便于在大多數機器上演示,我們將在本節中加載和微調經過預訓練的BERT小版本(“bert.small”)。在練習中,我們將展示如何微調大得多的“bert.base”以顯著提高測試精度
devices = d2l.try_all_gpus() bert,vocab = load_pretrained_model('bert.small',num_hiddens=256,ffn_num_hiddens=512,num_heads=4,num_layers=2,dropout=0.1,max_len=512,devices=devices)2 - 微調BERT的數據集
對于SNLI數據集的下游任務自然語言推斷,我們定義了一個定制的數據集類SNLIBERTDataset。在每個樣本中,前提和假設形成一對文本序列,并被打包成一個BERT輸入序列。利用預定義的BERT輸入序列的最大長度(max_len),持續移除輸入文本對中較長文本的最后一個標記,直到滿足max_len。為了加速生成用于微調BERT的SNLI數據集,我們使用4個工作進程并行生成訓練或測試樣本
class SNLIBERTDataset(torch.utils.data.Dataset):def __init__(self,dataset,max_len,vocab=None):all_premise_hypothesis_tokens = [[p_tokens,h_tokens] for p_tokens,h_tokens in zip(*[d2l.tokenize([s.lower() for s in sentences]) for sentences in dataset[:2]])]self.labels = dataset[2]self.vocab = vocabself.max_len = max_len(self.all_token_ids,self.all_segments,self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)print('read' + str(len(self.all_token_ids)) + ' examples')def _preprocess(self,all_premise_hypothesis_tokens):pool = multiprocessing.Pool(4) # 使用4個進程out = pool.map(self._mp_worker,all_premise_hypothesis_tokens)all_token_ids = [token_ids for token_ids,segments,valid_len in out]all_segments = [segments for token_ids,segments,valid_len in out]valid_lens = [valid_len for token_ids,segments,valid_len in out]return (torch.tensor(all_token_ids,dtype=torch.long),torch.tensor(all_segments,dtype=torch.long),torch.tensor(valid_lens))def _mp_worker(self,premise_hypothesis_tokens):p_tokens,h_tokens = premise_hypothesis_tokensself._truncate_pair_of_tokens(p_tokens,h_tokens)tokens,segments = d2l.get_tokens_and_segments(p_tokens,h_tokens)token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * (self.max_len - len(tokens))segments = segments + [0] * (self.max_len - len(segments))valid_len = len(tokens)return token_ids,segments,valid_lendef _truncate_pair_of_tokens(self,p_tokens,h_tokens):# 為BERT輸入中的'<CLS>'、'<SEP>'和'<SEP>'詞元保留位置while len(p_tokens) + len(h_tokens) > self.max_len - 3:if len(p_tokens) > len(h_tokens):p_tokens.pop()else:h_tokens.pop()def __getitem__(self,idx):return (self.all_token_ids[idx],self.all_segments[idx],self.valid_lens[idx],self.labels[idx])def __len__(self):return len(self.all_token_ids)下載完SNLI數據集后,我們通過實例化SNLIBERTDataset類來生成訓練和測試樣本,這些樣本在自然語言推斷的訓練和測試期間進行小批量讀取
import os import redef read_snli(data_dir,is_train):"""將SNLI數據集解析為前提、假設和標簽"""def extract_text(s):# 刪除我們不會使用的信息s = re.sub('\\(','',s)s = re.sub('\\(','',s)# 用一個空格替換兩個或多個連續的空格s = re.sub('\\s{2,}',' ',s)return s.strip()#label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}label_set = {0:'entailment',1: 'contradiction',2: 'neutral'}file_name = os.path.join(data_dir, 'train.txt' if is_train else 'test.txt')with open(file_name,'r') as f:rows = [row.split('\t') for row in f.readlines()[1:]]premises = [extract_text(row[0]) for row in rows]hypotheses = [extract_text(row[1]) for row in rows]labels = [label_set[int(row[2].replace('\n',''))] for row in rows]return premises,hypotheses,labels # 如果出現顯存不足的錯誤,請減少“batch_size”,在原始的BERT模型中,max_len=512batch_size,max_len,num_workers = 512,128,d2l.get_dataloader_workers() data_dir = 'SNLI' train_set = SNLIBERTDataset(read_snli(data_dir, True), max_len, vocab) test_set = SNLIBERTDataset(read_snli(data_dir, False), max_len, vocab) train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,num_workers=num_workers) test_iter = torch.utils.data.DataLoader(test_set, batch_size,num_workers=num_workers)3 - 微調BERT
用于自然語言推斷的微雕BERT只需要一個額外的多層感知機,該多層感知機由兩個全連接層組成(參見下面BERTClassifier類中的self.hidden和self.output)。這個多層感知機將特殊的“<cls>”詞元的BERT表示進行了轉換,該詞元同時編碼前提和假設的信息為自然語言推斷的三個輸出:蘊涵、矛盾和中性
class BERTClassifier(nn.Module):def __init__(self,bert):super(BERTClassifier,self).__init__()self.encoder = bert.encoderself.hidden = bert.hiddenself.output = nn.Linear(256,3)def forward(self,inputs):tokens_X,segments_X,valid_lens_x = inputsencoded_X = self.encoder(tokens_X,segments_X,valid_lens_x)return self.output(self.hidden(encoded_X[:,0,:]))在下文中,預訓練的BERT模型bert被送到用于下游應用的BERTClassifier實例net中,在BERT微調的常見實現中,只有額外的多層感知機(net.output)的輸出層的參數將從零開始學習。預訓練BERT編碼器(net.encoder)和額外的多層感知機的隱藏層(net.hidden)的所有參數都將進行微調
net = BERTClassifier(bert)回想一下,MaskLM類和NextSentencePred類在其使用的多層感知機中都有一些參數,這些參數是預訓練BERT模型bert中參數的一部分,因此是net中參數的一部分。然而,這些參數僅用于計算預訓練過程中的遮蔽語言模型損失和下一句預測損失。這兩個損失函數與微調下游應用無關,因此當BERT微調時,MaskLM和NextSentencePred中采用的多層感知機的參數不會更新(陳舊的,staled)
為了允許具有陳舊梯度的參數,標注ignore_stale_grad=Ture在step函數d2l.train_batch_ch13中被設置。我們通過該函數使用SNLI訓練集(train_iter)和測試集(test_iter)對net模型進行訓練和評估。由于計算資源有限,訓練和測試精度可以進一步提高:我們把對它的討論留在練習中
lr,num_epochs = 1e-4,5 trainer = torch.optim.Adam(net.parameters(),lr=lr) loss = nn.CrossEntropyLoss(reduction='none') d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)4 - 小結
- 我們可以針對下游應用對預訓練的BERT模型進行微調,例如在SNLI數據集上進行自然語言推斷
- 在微調過程中,BERT模型成為下游應用模型的一部分。僅與訓練前損失相關的參數在微調期間不會更新
總結
以上是生活随笔為你收集整理的自然语言推断:微调BERT的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python中path函数_示例1-pa
- 下一篇: 社科院与杜兰大学金融管理硕士项目——苦练