Fine-tuning BERT with PyTorch
This article adapts Chris McCormick's BERT fine-tuning tutorial to the Quora Question Pairs task of deciding whether two questions ask the same thing. (Most of the prose is a translation of the original post.)
Original blog post: https://mccormickml.com/2019/07/22/BERT-fine-tuning/
Original Colab notebook: https://colab.research.google.com/drive/1pTuQhug6Dhl9XalKB0zUGf4FIdYFlpcX
Project repository for this article: https://github.com/yxf975/pretraining_models_learning
Preface
This article drops much of the introductory material of the original English post and focuses on how to implement a basic BERT fine-tuning approach. It differs from Chris McCormick's solution in the following ways:
- It uses the Quora question pairs dataset
- It adds the option of running on multiple GPUs
- It wraps parts of the code into functions for easier use
- It adds a prediction step
I will cover how BERT and other pretrained models actually work in a separate topic, so let's get started!
Preparation
Check the GPU
For torch to use the GPU, we need to identify and specify the GPU as the device. Later, in the training loop, we will load the data onto this device.
```python
import torch

# If there's a GPU available...
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()
    print('There are %d GPU(s) available.' % n_gpu)
    print('We will use the GPU:', [torch.cuda.get_device_name(i) for i in range(n_gpu)])
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
    n_gpu = 0  # keep n_gpu defined, since later code checks it
```
Install the Transformers library
At the moment, Hugging Face's Transformers library seems to be the most widely adopted and most powerful PyTorch interface for working with BERT. Besides supporting a variety of different pretrained transformer models, the library also contains pre-built modifications of these models suited to specific tasks. For example, in this tutorial we will use BertForSequenceClassification.
The library also includes task-specific classes for token classification, question answering, next sentence prediction, and so on. Using these pre-built classes simplifies the process of adapting BERT for your purposes.
```python
!pip install transformers
```
Load the Quora Question Pairs data
The dataset is on Kaggle; register and log in to download it: https://www.kaggle.com/c/quora-question-pairs . I have also shared the dataset on Google Drive: https://drive.google.com/drive/folders/1kFkte0Kt2xLe6Ykl4O4_TrL2iCzorOYk
About the Quora Question Pairs dataset
The dataset comes from the Quora platform, where many people ask similarly worded questions. Multiple questions with the same intent can cause seekers to spend more time finding the best answer to their question, and make writers feel they need to answer several versions of the same question.
The task is to classify whether a pair of questions is a duplicate, a natural language processing problem. Solving it makes it easier to find high-quality answers, resulting in a better experience for Quora writers, seekers, and readers.
Load the data with pandas
```python
import pandas as pd
import numpy as np

# Load the dataset into a pandas dataframe.
train_data = pd.read_csv("./train.csv", index_col="id", nrows=10000)
train_data.head(6)
```
I display six rows here because the first positive sample does not appear until the sixth row.
| id | qid1 | qid2 | question1 | question2 | is_duplicate |
| --- | --- | --- | --- | --- | --- |
| 0 | 1 | 2 | What is the step by step guide to invest in share market in india? | What is the step by step guide to invest in share market? | 0 |
| 1 | 3 | 4 | What is the story of Kohinoor (Koh-i-Noor) Diamond? | What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back? | 0 |
| 2 | 5 | 6 | How can I increase the speed of my internet connection while using a VPN? | How can Internet speed be increased by hacking through DNS? | 0 |
| 3 | 7 | 8 | Why am I mentally very lonely? How can I solve it? | Find the remainder when [math]23^{24}[/math] is divided by 24,23? | 0 |
| 4 | 9 | 10 | Which one dissolve in water quikly sugar, salt, methane and carbon di oxide? | Which fish would survive in salt water? | 0 |
| 5 | 11 | 12 | Astrology: I am a Capricorn Sun Cap moon and cap rising…what does that say about me? | I’m a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me? | 1 |
The three attributes we actually care about are "question1", "question2", and their label "is_duplicate" (0 = not duplicate, 1 = duplicate).
Train/validation split
We split the training data into 80% for training and 20% for validation.
```python
from sklearn.model_selection import train_test_split

# Train/validation data split
X_train, X_val, y_train, y_val = train_test_split(train_data[["question1", "question2"]],
                                                  train_data["is_duplicate"],
                                                  test_size=0.2,
                                                  random_state=405633)
```
Tokenization & Input Formatting
BERT Tokenizer
```python
from transformers import BertTokenizer

# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
```
Find the maximum sentence length in the data
```python
# Calculate the maximum sentence length
max_len = 0
for _, row in train_data.iterrows():
    max_len = max(max_len, len(tokenizer(row['question1'], row['question2'])["input_ids"]))

print("max token length of the input:", max_len)

# Set the maximum token length (round up to the next power of two)
max_length = pow(2, int(np.log2(max_len) + 1))
print("max token length for BERT:", max_length)
```
Convert to BERT inputs
```python
from torch.utils.data import TensorDataset
from tqdm import tqdm

# Function to convert the data to BERT inputs
def convert_to_dataset_torch(data: pd.DataFrame, labels=pd.Series(data=None)) -> TensorDataset:
    input_ids = []
    attention_masks = []
    token_type_ids = []

    for _, row in tqdm(data.iterrows(), total=data.shape[0]):
        encoded_dict = tokenizer.encode_plus(row["question1"], row["question2"],
                                             max_length=max_length,
                                             pad_to_max_length=True,
                                             return_attention_mask=True,
                                             return_tensors='pt',
                                             truncation=True)
        # Add the encoded sentence pair to the list.
        input_ids.append(encoded_dict['input_ids'])
        token_type_ids.append(encoded_dict["token_type_ids"])
        # And its attention mask (simply differentiates padding from non-padding).
        attention_masks.append(encoded_dict['attention_mask'])

    # Convert the lists into tensors.
    input_ids = torch.cat(input_ids, dim=0)
    token_type_ids = torch.cat(token_type_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    if labels.empty:
        return TensorDataset(input_ids, attention_masks, token_type_ids)
    else:
        labels = torch.tensor(labels.values)
        return TensorDataset(input_ids, attention_masks, token_type_ids, labels)
```
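The two splits then need to be converted into TensorDatasets before they can be wrapped in DataLoaders. A minimal sketch of these calls, assuming the variable names `train` and `validation` that the DataLoaders below expect:

```python
# Convert the training and validation splits into TensorDatasets.
train = convert_to_dataset_torch(X_train, y_train)
validation = convert_to_dataset_torch(X_val, y_val)
```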
Put the data into DataLoaders
We will also use the torch DataLoader class to create an iterator over our dataset. This helps save memory during training because, unlike a for loop, an iterator does not require the entire dataset to be loaded into memory.
```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Set the batch size for the DataLoader (options from the paper: 16 or 32)
batch_size = 32

# Create the DataLoaders for the training and validation sets

# For training
train_dataloader = DataLoader(
    train,
    sampler=RandomSampler(train),  # Select batches randomly
    batch_size=batch_size
)

# For validation
validation_dataloader = DataLoader(
    validation,
    sampler=SequentialSampler(validation),  # Pull out batches sequentially
    batch_size=batch_size
)
```
Load the model
Load the pretrained model BertForSequenceClassification
We will use BertForSequenceClassification. This is the plain BERT model with a single linear layer for classification added on top, which we will use as a sentence-pair classifier. When we feed in our data, the entire pretrained BERT model and the additional untrained classification layer are trained on our specific task.
```python
from transformers import BertForSequenceClassification, AdamW, BertConfig

# Load BertForSequenceClassification, the pretrained BERT model with a single
# linear classification layer on top.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",         # Use the 12-layer BERT model, with an uncased vocab.
    num_labels=2,                # The number of output labels--2 for binary classification.
                                 # You can increase this for multi-class tasks.
    output_attentions=False,     # Whether the model returns attention weights.
    output_hidden_states=False,  # Whether the model returns all hidden states.
)

# Tell pytorch to run this model on the GPU.
model.cuda()

# Wrap the model for multi-GPU training if more than one GPU is available.
if n_gpu > 1:
    model = torch.nn.DataParallel(model)
```
Of course, the BERT network structure could also be modified to better fit our task; here I simply use the stock model.
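For reference, a modified model could look like the minimal sketch below: a hypothetical classifier (not the model used in this article) that takes the pooled [CLS] representation from `BertModel` and adds its own dropout and linear layer.

```python
import torch
from transformers import BertModel

class CustomBertClassifier(torch.nn.Module):
    """Hypothetical example of a modified BERT classifier (not used in this article)."""
    def __init__(self, num_labels=2, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = torch.nn.Dropout(dropout_prob)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        outputs = self.bert(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooled [CLS] representation
        logits = self.classifier(self.dropout(pooled_output))
        return logits
```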
Optimizer & learning rate scheduler
For fine-tuning, the authors of the BERT paper recommend choosing from the following values (from Appendix A.3 of the BERT paper); the optimizer and scheduler are set up accordingly below.
- Batch size: 16, 32 (set in the DataLoader above)
- Learning rate (Adam): 5e-5, 3e-5, 2e-5
- Number of epochs: 2, 3, 4
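A minimal sketch of the optimizer and scheduler setup, assuming AdamW and the linear warmup schedule from the transformers library, and the names (`optimizer`, `scheduler`, `epochs`) referenced by the training code below:

```python
from transformers import AdamW, get_linear_schedule_with_warmup

# AdamW optimizer with one of the learning rates recommended in the paper.
optimizer = AdamW(model.parameters(),
                  lr=2e-5,   # learning rate: 5e-5, 3e-5 or 2e-5
                  eps=1e-8)  # epsilon, to avoid division by zero

# Number of training epochs: 2, 3 or 4.
epochs = 2

# Total number of training steps is [number of batches] x [number of epochs].
total_steps = len(train_dataloader) * epochs

# Linear learning rate schedule with warmup.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)
```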
Training
Time formatting helper
```python
import time
import datetime

# Helper function for formatting elapsed times as hh:mm:ss
def format_time(elapsed):
    '''Takes a time in seconds and returns a string hh:mm:ss'''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))
```
Fit function
```python
from tqdm import tqdm

def fit_batch(dataloader, model, optimizer, epoch):
    total_train_loss = 0

    for batch in tqdm(dataloader, desc=f"Training epoch:{epoch+1}", unit="batch"):
        # Unpack this batch from the dataloader and copy each tensor to the device.
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        labels = batch[3].to(device)

        # Clear any previously calculated gradients before performing a backward pass.
        model.zero_grad()

        # Perform a forward pass (evaluate the model on this training batch).
        outputs = model(input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_masks,
                        labels=labels)
        loss = outputs[0]
        # With DataParallel the loss is a vector of per-GPU losses, so average it.
        if n_gpu > 1:
            loss = loss.mean()
        total_train_loss += loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Clip the norm of the gradients to 1.0 to avoid exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and take a step using the computed gradient.
        optimizer.step()

        # Update the learning rate.
        scheduler.step()

    return total_train_loss
```
Evaluation function for validation
```python
from sklearn.metrics import accuracy_score

def eval_batch(dataloader, model, metric=accuracy_score):
    total_eval_accuracy = 0
    total_eval_loss = 0
    predictions, predicted_labels = [], []

    for batch in tqdm(dataloader, desc="Evaluating", unit="batch"):
        # Unpack this batch from the dataloader and copy each tensor to the device.
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        labels = batch[3].to(device)

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            outputs = model(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_masks,
                            labels=labels)
            loss = outputs[0]
            logits = outputs[1]

        # With DataParallel the loss is a vector of per-GPU losses, so average it.
        if n_gpu > 1:
            loss = loss.mean()
        total_eval_loss += loss.item()

        # Move logits and labels to CPU.
        logits = logits.detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()

        # Calculate the accuracy for this batch of validation pairs, and
        # accumulate it over all batches.
        y_pred = np.argmax(logits, axis=1).flatten()
        total_eval_accuracy += metric(label_ids, y_pred)

        predictions.extend(logits.tolist())
        predicted_labels.extend(y_pred.tolist())

    return total_eval_accuracy, total_eval_loss, predictions, predicted_labels
```
Train function
```python
def train(train_dataloader, validation_dataloader, model, optimizer, epochs):
    # List to store a number of quantities such as
    # training and validation loss, validation accuracy, and timings.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()

    for epoch in range(0, epochs):
        # Measure how long the training epoch takes.
        t0 = time.time()

        # Put the model into training mode.
        model.train()

        total_train_loss = fit_batch(train_dataloader, model, optimizer, epoch)

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        t0 = time.time()

        # Put the model in evaluation mode--the dropout layers behave differently
        # during evaluation.
        model.eval()

        total_eval_accuracy, total_eval_loss, _, _ = eval_batch(validation_dataloader, model)

        # Report the final accuracy for this validation run.
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("\n")
        print(f"score: {avg_val_accuracy}")

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print(f"Validation Loss: {avg_val_loss}")
        print("\n")

        # Record all statistics from this epoch.
        training_stats.append(
            {
                'epoch': epoch,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Valid. score.': avg_val_accuracy,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )

    print("")
    print("Training complete!")
    print(f"Total training took {format_time(time.time()-total_t0)}")
    return training_stats
```
Start training
```python
import random

# Set the seed value all over the place to make this reproducible.
seed_val = 2020

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed_val)

training_stats = train(train_dataloader, validation_dataloader, model, optimizer, epochs)
```
Inspect the evaluation statistics collected during training
```python
df_stats = pd.DataFrame(training_stats).set_index('epoch')
df_stats
```
Prediction
Prediction function
```python
def predict(dataloader, model):
    prediction = list()

    for batch in tqdm(dataloader, desc="predicting", unit="batch"):
        # Unpack this batch from the dataloader and copy each tensor to the device.
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        token_type_ids = batch[2].to(device)

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            outputs = model(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_masks)
            logits = outputs[0]

        # Move logits to CPU.
        logits = logits.detach().cpu().numpy()
        prediction.append(logits)

    pred_logits = np.concatenate(prediction, axis=0)
    pred_label = np.argmax(pred_logits, axis=1).flatten()
    print("done")
    return (pred_label, pred_logits)
```
Create a DataLoader for the test set
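The test set has to be loaded before the DataLoader can be built. A minimal sketch, assuming test.csv from the same Kaggle competition and the variable name `test_data` used below (the nrows limit and fillna are just illustrative precautions):

```python
# Load the test set (test.csv from the Kaggle competition page).
test_data = pd.read_csv("./test.csv", index_col="test_id", nrows=10000)
# Keep only the question columns and guard against missing texts.
test_data = test_data[["question1", "question2"]].fillna("")
```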
```python
# Create the DataLoader for the test data.
prediction_data = convert_to_dataset_torch(test_data)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)
```
Make predictions
The logits can also be converted into the corresponding probabilities with softmax.
```python
y_pred, logits = predict(prediction_dataloader, model)

# Get the corresponding probabilities
prob = torch.nn.functional.softmax(torch.tensor(logits), dim=1)
```
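If a Kaggle submission is the goal, the probability of the positive class can then be written out; a minimal sketch, assuming the competition's test_id / is_duplicate submission format and the `test_data` frame loaded above:

```python
# Probability that each test pair is a duplicate (positive class).
submission = pd.DataFrame({
    "test_id": test_data.index,
    "is_duplicate": prob[:, 1].numpy()
})
submission.to_csv("submission.csv", index=False)
```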
Summary
This article demonstrated how to fine-tune a pretrained BERT model for the Quora question-pairs task. A similar fine-tuning approach can be applied to other, comparable text classification problems.
Of course, for more accurate predictions you may need a better, more suitable pretrained model, a network architecture modified to fit the task better, or additional techniques such as adversarial training.