【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型
@Author:Runsen
BERT模型在NLP各項任務中大殺四方,那么我們如何使用這一利器來為我們日常的NLP任務來服務呢?首先介紹使用BERT做文本多標簽分類任務。
文本多標簽分類是常見的NLP任務,文本介紹了如何使用Bert模型完成文本多標簽分類,并給出了各自的步驟。
參考官方教程:https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html
復旦大學邱錫鵬老師課題組的研究論文《How to Fine-Tune BERT for Text Classification?》。
論文: https://arxiv.org/pdf/1905.05583.pdf
這篇論文的主要目的在于在文本分類任務上探索不同的BERT微調方法并提供一種通用的BERT微調解決方法。這篇論文從三種路線進行了探索:
- (1) BERT自身的微調策略,包括長文本處理、學習率、不同層的選擇等方法;
- (2) 目標任務內、領域內及跨領域的進一步預訓練BERT;
- (3) 多任務學習。微調后的BERT在七個英文數(shù)據(jù)集及搜狗中文數(shù)據(jù)集上取得了當前最優(yōu)的結果。
作者的實現(xiàn)代碼: https://github.com/xuyige/BERT4doc-Classification
數(shù)據(jù)集來源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv
該數(shù)據(jù)集包含 6 個不同的標簽(計算機科學、物理、數(shù)學、統(tǒng)計學、生物學、金融),以根據(jù)摘要和標題對研究論文進行分類。
標簽列中的值 1 表示標簽屬于該標簽。每個論文有多個標簽為 1。
Bert模型加載
Transformer 為我們提供了一個基于 Transformer 的可以微調的預訓練網(wǎng)絡。
由于數(shù)據(jù)集是英文, 因此這里選擇加載bert-base-uncased。
具體下載鏈接:https://huggingface.co/bert-base-uncased/tree/main
from transformers import BertTokenizerFast as BertTokenizer # 直接下載很很慢,建議下載到文件夾中 # BERT_MODEL_NAME = "bert-base-uncased" BERT_MODEL_NAME = "model/bert-base-uncased" tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)微調BERT模型
bert微調就是在預訓練模型bert的基礎上只需更新后面幾層的參數(shù),這相對于從頭開始訓練可以節(jié)省大量時間,甚至可以提高性能,通常情況下在模型的訓練過程中,我們也會更新bert的參數(shù),這樣模型的性能會更好。
微調BERT模型主要在D_out進行相關的改變,去除segment層,直接采用了字符輸入,不再需要segment層。
下面是微調BERT的主要代碼
class BertClassifier(nn.Module):def __init__(self, num_labels: int, BERT_MODEL_NAME, freeze_bert=False):super().__init__()self.num_labels = num_labelsself.bert = BertModel.from_pretrained(BERT_MODEL_NAME)# hidden size of BERT, hidden size of our classifier, and number of labels to classifyD_in, H, D_out = self.bert.config.hidden_size, 50, num_labels# Instantiate an one-layer feed-forward classifierself.classifier = nn.Sequential(nn.Dropout(p=0.3),nn.Linear(D_in, H),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(H, D_out),)# lossself.loss_func = nn.BCEWithLogitsLoss()if freeze_bert:print("freezing bert parameters")for param in self.bert.parameters():param.requires_grad = Falsedef forward(self, input_ids, attention_mask, labels=None):outputs = self.bert(input_ids, attention_mask=attention_mask)last_hidden_state_cls = outputs[0][:, 0, :]logits = self.classifier(last_hidden_state_cls)if labels is not None:predictions = torch.sigmoid(logits)loss = self.loss_func(predictions.view(-1, self.num_labels), labels.view(-1, self.num_labels))return losselse:return logits其他
關于數(shù)據(jù)預處理,DataLoader等代碼有點多,這里不一一列舉,需要代碼的在公眾號回復:”bert“ 。
最后的訓練結果如下所示:
總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习和目标检测系列教程 17-300
- 下一篇: 【小白学习C++ 教程】二十一、C++