KBQA-Bert学习记录-CRF模型
目錄
一、整體架構(gòu)
1.定義CRF類(lèi),初始化相關(guān)參數(shù)
2.定義forward函數(shù)
3.forword調(diào)用的函數(shù):_validate
4.forward調(diào)用的函數(shù):_conputer_score
5.forward調(diào)用的函數(shù):_compute_normalizer
6.forward調(diào)用的函數(shù):_viterbi_decode
7.外部調(diào)用的函數(shù):decode
該項(xiàng)目中,使用BERT+CRF進(jìn)行NER任務(wù),因此先構(gòu)造CRF模型。具體實(shí)現(xiàn)過(guò)程中需要注意的細(xì)節(jié)已在代碼中包含。
一、整體架構(gòu)
通過(guò)bert生成序列之后(其他的模型比如LSTM什么的也一樣,都會(huì)生成一個(gè)預(yù)測(cè)序列),我們得到了形狀是(batch_size, sentence_length, number_of_tags)的結(jié)果,也就是,對(duì)每一句話(huà),的每一個(gè)字,有number_of_tags這么多的預(yù)測(cè)結(jié)果。假如我們的實(shí)體類(lèi)型有三個(gè)"B", "I", "O",一個(gè)batch有32句話(huà),一句話(huà)被統(tǒng)一成了64個(gè)單詞,那么生成的結(jié)果就是(32, 64, 3)。注意這里的batch_size和sentence_length的位置,可能會(huì)由于代碼的不同,調(diào)換順序。
生成的結(jié)果就是我們的發(fā)射分?jǐn)?shù)。要計(jì)算損失,我們還需要計(jì)算發(fā)射分?jǐn)?shù)中,正確路徑對(duì)應(yīng)的分?jǐn)?shù);以及發(fā)射分?jǐn)?shù)中,所有路徑合起來(lái)的分?jǐn)?shù)。
同時(shí),我們還需要對(duì)所有路徑合起來(lái)的分?jǐn)?shù)進(jìn)行處理,由于計(jì)算損失的時(shí)候,會(huì)讓這個(gè)總分作為分母,因此,采用的是先取exp(),再求和sum(),再取對(duì)數(shù)log(),而這個(gè)運(yùn)算只需要pytorch的一行代碼即可完成:torch.logsumexp()。
最后,我們還希望得到一條最佳路徑,于是需要維特比解碼得到。
因此,在這個(gè)類(lèi)中,我們需要定義不同的函數(shù)來(lái)實(shí)現(xiàn)不同的功能:
1. __init__必須定義,初始化參數(shù)
2.forward必須定義,前向傳播,得到損失值。這里面會(huì)調(diào)用其他函數(shù),用于計(jì)算損失。
3.計(jì)算正確路徑分?jǐn)?shù)的函數(shù)
4.計(jì)算所有路徑總分的函數(shù)
5.維特比解碼函數(shù)
6.能讓外接調(diào)用,得到最佳路徑的函數(shù)
注意:下面所有函數(shù),都在CRF類(lèi)里面,這里以分段的形式記錄。
1.定義CRF類(lèi),初始化相關(guān)參數(shù)
class CRF(nn.Module):def __init__(self, num_tags : int = 2, batch_first : bool = True) -> None:super(CRF, self).__init__()self.num_tags = num_tagsself.batch_first = batch_first# start到其他(不含end)的得分self.start_transitions = nn.Parameter(torch.empty(num_tags))# 其他(不含start)到end的得分self.end_transitions = nn.Parameter(torch.empty(num_tags))# 轉(zhuǎn)移分?jǐn)?shù)矩陣self.transitions = nn.Parameter(torch.empty((num_tags, num_tags)))self.reset_parameters()def reset_parameters(self):'''將初始化的分?jǐn)?shù)限定在-0.1到0.1之間'''init_range = 0.1nn.init.uniform_(self.start_transitions, -init_range, init_range)nn.init.uniform_(self.end_transitions, -init_range, init_range)nn.init.uniform_(self.transitions, -init_range, init_range)2.定義forward函數(shù)
forward函數(shù)所需要的其他函數(shù),后面補(bǔ)充。通過(guò)forward函數(shù)之后,返回的是我們所需要的損失值。
def forward(self, emissions: torch.Tensor,tags: torch.Tensor = None,mask: Optional[torch.ByteTensor] = None,reduction: str = 'mean') -> torch.Tensor:self._validate(emissions, tags=tags, mask=mask)# reduction:損失值模式,是均值還是求和作為損失reduction = reduction.lower()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f"invalid reduction {reduction}")if mask is None:mask = torch.ones_like(tags, dtype=torch.uint8)if self.batch_first:# 發(fā)射分?jǐn)?shù)形狀:(seq_length, batch_size, tag_num)emissions = emissions.transpose(0, 1)tags = tags.transpose(0, 1)mask = mask.transpose(0, 1)# 計(jì)算正確標(biāo)簽序列的發(fā)射分?jǐn)?shù)和轉(zhuǎn)移分?jǐn)?shù)之和, shape: (batch_size, )numerator = self._cumputer_score(emissions=emissions, tags=tags, mask=mask)# 計(jì)算所有序列發(fā)射分?jǐn)?shù)和轉(zhuǎn)移分?jǐn)?shù)之和, shape: (batch_size, )denominator = self._compute_normalizer(emissions=emissions, mask=mask)# 二者相減, shape: (batch_size, )llh = denominator - numerator# 根據(jù)不同的設(shè)定返回不同形式的分?jǐn)?shù)if reduction == 'none':return llhif reduction == 'sum':return llh.sum()if reduction == 'mean':return llh.mean()if reduction == 'token_mean':return llh.sum() / mask.float().sum()3.forword調(diào)用的函數(shù):_validate
主要是用來(lái)確保所有輸入數(shù)據(jù)的維度應(yīng)該是我們所要求的維度。
def _validate(self, emissions: torch.Tensor,tags: Optional[torch.LongTensor] = None,mask: Optional[torch.ByteTensor] = None) -> None:if emissions.dim() != 3:raise ValueError(f"emissions must have dimension of 3, got{emissions.dim()}")if emissions.size(2) != self.num_tags:raise ValueError(f"expected last dimission of emission is {self.num_tags},"f"got {emissions.size(2)}")if tags is not None:if emissions.shape[:2] != mask.shape:raise ValueError(f"the first two dimensions of mask and emissions must match"f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}")no_empty_seq = not self.batch_first and mask[0].all()no_empty_seq_bf = self.batch_first and mask[:, 0].all()if not no_empty_seq and not no_empty_seq_bf:raise ValueError('mask of the first timestep must all be on.')4.forward調(diào)用的函數(shù):_computer_score
該函數(shù)用來(lái)計(jì)算最佳路徑的分?jǐn)?shù)。
def _computer_score(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: torch.ByteTensor) -> torch.Tensor:# batch secondassert emissions.dim() == 3 and tags.dim() == 2assert emissions.shape[:2] == tags.shapeassert emissions.size(2) == self.num_tagsassert mask.shape == tags.shape# 每個(gè)mask,開(kāi)頭一定是1,否則相當(dāng)于句子就沒(méi)了。assert mask[0].all()seq_length, batch_size = tags.shapemask = mask.float()# start,轉(zhuǎn)移到其他所有標(biāo)簽的分?jǐn)?shù),不包含end# 根據(jù)實(shí)際的tag的開(kāi)頭的詞,得到從start到每句話(huà)開(kāi)頭的類(lèi)型的分?jǐn)?shù)。# 這里是start到第一個(gè)詞的轉(zhuǎn)移分?jǐn)?shù),shape: (batch_size,)score = self.start_transitions[tags[0]]# 接下來(lái)是預(yù)測(cè)的每句話(huà)的開(kāi)頭應(yīng)當(dāng)是什么tag,如果有3個(gè)tag,那么每個(gè)詞都會(huì)有對(duì)應(yīng)的三個(gè)分?jǐn)?shù),分別對(duì)應(yīng)每一個(gè)tag# 但是我們實(shí)際的tag是在tags[0]里面的,而預(yù)測(cè)的三個(gè)值,分?jǐn)?shù)不一定是多少# 比如實(shí)際的第一個(gè)詞tag是B,預(yù)測(cè)的BIO的三個(gè)分?jǐn)?shù)分別為:(0.1,0.5,04)# 那么我們把0.1這個(gè)分?jǐn)?shù)加上。這個(gè)就是發(fā)射分?jǐn)?shù),也就是預(yù)測(cè)的分?jǐn)?shù)。score += emissions[0, torch.arange(batch_size), tags[0]]# 至此,我們完成了從start轉(zhuǎn)移到第一個(gè)詞的轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù)# 接下來(lái)是每個(gè)詞到下一個(gè)詞的轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù),全加到一塊for i in range(1, seq_length):# 轉(zhuǎn)移分?jǐn)?shù)score += self.transitions[tags[i-1], tags[i]] * mask[i]# 發(fā)射分?jǐn)?shù)score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]# 取到最后一個(gè)詞的tag# 使用的mask是形如:[1,1,1,1,0,0,0],后面的0是padding的,因此沒(méi)字了# 因此通過(guò)下面的方式,取到1的和,減去1,就是最后一個(gè)詞的索引了。seq_end = mask.long().sum(dim=0) - 1last_tag = tags[seq_end, torch.arange(batch_size)]# 最后一個(gè)詞轉(zhuǎn)移到end的分?jǐn)?shù)score += self.end_transitions[last_tag]return score5.forward調(diào)用的函數(shù):_compute_normalizer
這里計(jì)算所有路徑的分?jǐn)?shù)之和。并取一個(gè)logsumexp
def _compute_normalizer(self, emissions: torch.Tensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length = emissions.size(0)# emissions[0],因?yàn)榈谝粋€(gè)維度是句子長(zhǎng)度,因此emissions[0]就是每一個(gè)句子的開(kāi)頭的詞,對(duì)應(yīng)的發(fā)射分?jǐn)?shù)# 并且每一個(gè)分?jǐn)?shù)是有num_tags這么多。因此emissions[0]就是對(duì)所有開(kāi)頭的詞,對(duì)每一個(gè)標(biāo)簽預(yù)測(cè)的分?jǐn)?shù)。# 再加上start標(biāo)志到每一個(gè)標(biāo)簽的分?jǐn)?shù),就是一個(gè)整體的開(kāi)頭分?jǐn)?shù)之和。score = self.start_transitions + emissions[0]# 接下來(lái)把所有的轉(zhuǎn)移分?jǐn)?shù),發(fā)射分?jǐn)?shù)全部加起來(lái)。for i in range(1, seq_length):# 原來(lái)是(batch_size, num_tags), 現(xiàn)在是(batch_size, num_tags, 1)broadcast_score = score.unsqueeze(dim=2)# 對(duì)于第i個(gè)詞,原來(lái)是(batch_size, num_tags), 現(xiàn)在是(batch_size, 1, num_tags)broadcast_emission = emissions[i].unsqueeze(1)# 先把開(kāi)頭的分?jǐn)?shù)和轉(zhuǎn)移矩陣加起來(lái),便得到了開(kāi)頭的每一個(gè)tag,轉(zhuǎn)移到其他每一個(gè)tag的概率# 再把發(fā)射矩陣加上,便得到了該單詞的總分?jǐn)?shù)。其中會(huì)自動(dòng)使用broad cast機(jī)制next_score = broadcast_score + self.transitions + broadcast_emission# 對(duì)總分?jǐn)?shù),在第二個(gè)維度求一個(gè)對(duì)數(shù)域的分?jǐn)?shù)。第二個(gè)維度,也就是轉(zhuǎn)移矩陣的行# 我們求的是所有路徑的總分?jǐn)?shù),要對(duì)這個(gè)分?jǐn)?shù)求和。# 假如對(duì)第二個(gè)詞來(lái)說(shuō),可能由第一個(gè)詞的num_tags那么多的可能性過(guò)來(lái),那么就把所有的可能性加起來(lái)# 這樣得到的就是對(duì)于第二個(gè)詞來(lái)說(shuō)的總分?jǐn)?shù)。因此,把轉(zhuǎn)移矩陣的行,也就是前一個(gè)詞可能的tag,全部加起來(lái)即可# 也就是在第二個(gè)維度上求和。這樣就得到了總分?jǐn)?shù),我們對(duì)這個(gè)總分?jǐn)?shù)進(jìn)行對(duì)數(shù)域計(jì)算即可(取e,求和,取對(duì)數(shù))。next_score = torch.logsumexp(next_score, dim=1)# 通過(guò)mask,如果對(duì)應(yīng)的單詞位置有值,也就是我們需要這個(gè)分?jǐn)?shù),那么就使用next_score# 如果對(duì)應(yīng)的位置沒(méi)值,那么這個(gè)分?jǐn)?shù)不需要加上,就取原來(lái)的scorescore = torch.where(mask[i].unsqueeze(1), next_score, score)# 最后把單詞轉(zhuǎn)移到end的分?jǐn)?shù)加上score += self.end_transitions# 返回值取對(duì)數(shù)域的值,把所有的詞的分?jǐn)?shù)再求和一遍return torch.logsumexp(score, dim=1)6.forward調(diào)用的函數(shù):_viterbi_decode
維特比解碼,得到最佳路徑。
def _viterbi_decode(self, emissions: torch.FloatTensor,mask: torch.ByteTensor) -> List[List[int]]:assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length, batch_size = mask.shapescore = self.start_transitions + emissions[0]history = []for i in range(1, seq_length):broadcast_score = score.unsqueeze(2)broadcast_emission = emissions[i].unsqueeze(1)next_score = broadcast_score + self.transitions + broadcast_emission# 在第一個(gè)維度上面求最大,消掉第一個(gè)維度,那么剩下的就是"到下一個(gè)類(lèi)型概率最大的那個(gè)類(lèi)型"# 這個(gè)max返回值有2個(gè),一個(gè)是求完最大值后的結(jié)果,形狀是(B, tag_num),一個(gè)是每個(gè)最大值所在的索引# 兩個(gè)返回結(jié)果形狀一致# 選最好的轉(zhuǎn)移分?jǐn)?shù)next_score, indices = next_score.max(dim=1)score = torch.where(mask[i].unsqueeze(1), next_score, score)# 上一個(gè)詞轉(zhuǎn)移到這個(gè)詞時(shí),分?jǐn)?shù)最高的那些值的索引history.append(indices)score += self.end_transitionsseq_ends = mask.long().sum(dim=0) - 1best_tags_list = []for idx in range(batch_size):# 取到分?jǐn)?shù)最高的標(biāo)簽,就是最后一個(gè)詞的標(biāo)簽的索引# 選最好的發(fā)射分?jǐn)?shù)_, best_last_tag = score[idx].max(dim=0)best_tags = [best_last_tag.item()]# seq_ends存了每個(gè)句子序列的最后一個(gè)詞的索引。for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(best_last_tag.item())best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list7.外部調(diào)用的函數(shù):decode
該函數(shù)調(diào)用了上面的維特比解碼,外部可通過(guò)model.decode調(diào)用,返回最佳路徑。
def decode(self, emissions: torch.Tensor,mask: Optional[torch.ByteTensor]=None) -> List[List[int]]:self._validate(emissions=emissions, mask=mask)if mask is None:mask = emissions.new_ones(emissions.shape[:2],dtype=torch.uint8)if self.batch_first:emissions = emissions.transpose(0, 1)mask = mask.transpose(0, 1)return self._viterbi_decode(emissions, mask)總結(jié)
以上是生活随笔為你收集整理的KBQA-Bert学习记录-CRF模型的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 微信小程序中使用tabBar
- 下一篇: 华为rh2285 v1的装上独立显卡,并