SiameseNet(Learning Text Similarity with Siamese Recurrent Networks)
SiameseNet研究意義:
1、基于LSTM的孿生網絡結構,為后來研究打下良好的基礎,有很深遠的影響
2、提出了幾種文本增強的思路
注:
? ? ? ?孿生網絡指的是輸入是成對的,并且兩部分的網絡結構和參數都是一樣的,也就是只有一個網絡
? ? ? ?偽孿生網絡是指模型輸入相同,模型結構相同,但模型參數不共享
?
本文主要結構如下所示:
一、Abstract
? ? ? ? ?摘要部分主要講解本篇論文提出一個基于字符級別變長序列的雙向LSTM網絡結構,并且文中提出多種文本增強的思路,模型不僅能學到不同詞之間的語義差異性與語義不變性
二、Introduction
? ? ? ? ?主要介紹本文表示在自然語言處理中扮演者重要的角色,并且有很多的應用場景,并且說出了文本的語義表示,即不同的文本但是存在意思是一樣的,如: 12pm、noon、12.00h,最后介紹了本文中的模型在崗位分類任務上的效果。
三、Related Work
? ? ? ? ? ? ? ? ? 主要介紹相關的工作,包括神經網絡以及詞表示最近幾年在nlp任務中的發展,以及孿生網絡的發展和應用,最后介紹相同語義的不同表示方法歸一化問題
? ? ? ? ? ? ? ? ?
四、Siamese Recurrent Neural Network
? ? ? ? ? 主要講解孿生網絡的模型輸入數據、相關參數矩陣介紹、對比損失函數公式介紹
? ? ? ? ??
五、Experiments
? ? ? ? ? ? ?實驗部分主要首先講解了baseline模型n-gram matcher,然后介紹了文本數據增強的方法,主要有隨機替換字符和一定比例的拼寫錯誤、同義詞替換、添加多余信息可以增加魯棒性、以及人工反饋。
六、Discussion
? ? ? ? ? ? 最后一部分對本文進行總結,并提出了未來可以優化的方向? ? ? ? ? ?
?
創新點: 利用孿生網絡接收數據,雙向LSTM網絡提取特征,對比損失函數優化模型,并且提出多種文本數據增強的方法
啟發點: 循環神經網絡結構處理文本語義特征擁有比較好的表現
?
七、Code
# -*- coding: utf-8 -*-# @Time : 2021/2/12 下午6:57 # @Author : TaoWang # @Description : 定義網絡結構import torch import torch.nn as nnclass SiameseNet(torch.nn.Module):def __init__(self):super(SiameseNet, self).__init__()self.embedding = nn.Embedding(char_size, embedding_size)self.embedding.weight.data.copy_(torch.from_numpy(embedding))self.bi_lstm = nn.LSTM(embedding_size,lstm_hidden_size,num_layers=2,dropout=0.3,batch_first=True,bidirectional=True)self.dense = nn.Linear(linear_hidden_size, linear_hidden_size)self.dropout = nn.Dropout(0.3)def forward(self, a, b):""":param a::param b::return:"""embedding_a = self.embedding(a)embedding_b = self.embedding(b)lstm_a, _ = self.bi_lstm(embedding_a)lstm_b, _ = self.bi_lstm(embedding_b)avg_a = torch.mean(lstm_a, dim=1)avg_b = torch.mean(lstm_b, dim=1)out_a = self.dropout(torch.tanh(self.dense(avg_a)))out_b = self.dropout(torch.tanh(self.dense(avg_b)))cos_dis = torch.cosine_similarity(out_a, out_b, dim=1, eps=1e-9)return cos_dis# 定義模型損失函數 class ContrastiveLoss(nn.Module):def __init__(self):super(ContrastiveLoss, self).__init__()def forward(self, Ew, y):""":param Ew::param y::return:"""L_1 = 0.25 * (1.0 - Ew) * (1.0 - Ew)L_0 = torch.where(Ew < m * torch.ones_like(Ew), torch.full_like(Ew, 0), Ew) \* torch.where(Ew < m * torch.ones_like(Ew), torch.full_like(Ew, 0), Ew)loss = y * 1.0 * L_1 + (1 - y) * 1.0 * L_0return loss.sum()?
?
?
?
總結
以上是生活随笔為你收集整理的SiameseNet(Learning Text Similarity with Siamese Recurrent Networks)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: GGNN(Gated Graph Seq
- 下一篇: Comp-Agg (A Compare-