从语言模型到Seq2Seq:Transformer如戏,全靠Mask
作者丨蘇劍林
單位丨追一科技
研究方向丨NLP,神經網絡
個人主頁丨kexue.fm
相信近一年來(尤其是近半年來),大家都能很頻繁地看到各種 Transformer 相關工作(比如 BERT、GPT、XLNet 等等)的報導,連同各種基礎評測任務的評測指標不斷被刷新。同時,也有很多相關的博客、專欄等對這些模型做科普和解讀。
俗話說,“外行看熱鬧,內行看門道”,我們不僅要在“是什么”這個層面去理解這些工作,我們還需要思考“為什么”。這個“為什么”不僅僅是“為什么要這樣做”,還包括“為什么可以這樣做”。比如,在談到 XLNet 的亂序語言模型時,我們或許已經從諸多介紹中明白了亂序語言模型的好處,那不妨更進一步思考一下:
為什么 Transformer 可以實現亂序語言模型?是怎么實現的?RNN 可以實現嗎?
本文從對 Attention 矩陣進行 Mask 的角度,來分析為什么眾多 Transformer 模型可以玩得如此“出彩”的基本原因,正如標題所述“Transformer 如戲,全靠 Mask”,這是各種花式 Transformer 模型的重要“門道”之一。
讀完本文,你或許可以了解到:
1. Attention 矩陣的 Mask 方式與各種預訓練方案的關系;
2. 直接利用預訓練的 BERT 模型來做 Seq2Seq 任務。
背景
自
總的來說,這些以預訓練為基礎的工作層出不窮,有種琳瑯滿目的感覺。甚至一定程度上來說,如果你還沒有微調過 BERT ,那已經算是落后于主流的 NLP 技術了。?
花式預訓練
眾所周知,傳統的模型預訓練手段就是語言模型,比如 ELMo [1] 模型就是以 BiLSTM 為基礎架構、用兩個方向的語言模型分別預訓練兩個方向的 LSTM 的,后面的 OpenAI 的 GPT、GPT-2?[2] 也是堅定不移地堅持著用祖傳的(標準的、單向的)語言模型來預訓練。?
然而,還有更多花樣的預訓練玩法。比如 BERT [3] 就用了稱之為“掩碼語言模型(Masked Language Model)”的方式來預訓練,不過這只是普通語言模型的一種變體;還有 XLNet [4]則提出了更徹底的“Permutation Language Modeling”,我們可以稱之為“亂序語言模型”;還有 UNILM [5] 模型,直接用單個 BERT 的架構做 Seq2Seq,你可以將它作為一種預訓練手段,又或者干脆就用它來做 Seq2Seq 任務。
如此花樣百出,讓我們不禁疑問:為什么剛好在 Transformer 流行的時代,才出現這種各種大型預訓練模型“百花齊放,百家爭鳴”的現象?
Transformer專屬
事實上,除了單向語言模型及其簡單變體掩碼語言模型之外,UNILM 的 Seq2Seq 預訓練、XLNet 的亂序語言模型預訓練,基本可以說是專為 Transformer 架構定制的。說白了,如果是 RNN 架構,根本就不能用亂序語言模型的方式來預訓練,至于 Seq2Seq 的預訓練方式,則必須同時引入兩個模型(encoder 和 decoder),而無法像 Transformer 架構一樣,可以一個模型搞定。
這其中的奧妙主要在 Attention 矩陣之上。Attention 實際上相當于將輸入兩兩地算相似度,這構成了一個大小的相似度矩陣(即 Attention 矩陣,n 是句子長度,本文的 Attention 均指 Self Attention),這意味著它的空間占用量是量級,相比之下,RNN 模型、CNN 模型只不過是 ?(n),所以實際上 Attention 通常更耗顯存。
然而,有弊也有利,更大的空間占用也意味著擁有了更多的可能性,我們可以通過往這個級別的 Attention 矩陣加入各種先驗約束,使得它可以做更靈活的任務。說白了,也就只有純 Attention 的模型,才有那么大的“容量”去承載那么多的“花樣”。
而加入先驗約束的方式,就是對 Attention 矩陣進行不同形式的 Mask,這便是本文要關注的焦點。
分析
在
這里的,分別代表 query、key、value 的向量序列,其中我們可以認為 key 和 value 是一一對應的,而則是將 query、key 的向量兩兩做內積,然后用 softmax?歸一化,就得到一個的 Attention 矩陣,它描述的就是 query 和 key 之間任意兩個元素的關聯強度,后面我們要講的故事,都是在這個 Attention 矩陣上下功夫。最后再與 V 相乘,相當于按照這個關聯強度將 V 的各個向量加權求和,最終輸出一個的向量序列。?
目前最常用的 Attention 方式當數 Self Attention,即 Q, K, V 都是同一個向量序列經過線性變換而來的,而 Transformer 則是 Self Attention 跟 Position-Wise 全連接層(相當于 kernel size 為 1 的一維卷積)的組合。所以,Transformer 就是基于 Attention 的向量序列到向量序列的變換。?
在本節中,我們將會比較詳細地分析 Attention 矩陣的 Mask 方式,這分別對應單向語言模型、亂序語言模型、Seq2Seq 的實現原理。
單向語言模型
語言模型可以說是一個無條件的文本生成模型,如果讀者還不了解文本生成模型,可以自行查閱相關資料并配合
我們一般說的“語言模型”,就是指單向的(更狹義的只是指正向的)語言模型。語言模型的關鍵點是要防止看到“未來信息”。如上式,預測 x1 的時候,是沒有任何外部輸入的;而預測 x2 的時候,只能輸入 x1,預測 x3 的時候,只能輸入 x1,x2;依此類推。
▲?單向語言模型圖示。每預測一個token,只依賴于前面的token。RNN 模型是天然適合做語言模型的,因為它本身就是遞歸的運算;如果用 CNN 來做的話,則需要對卷積核進行 Mask,即需要將卷積核對應右邊的部分置零。如果是 Transformer 呢?那需要一個下三角矩陣形式的 Attention 矩陣:
▲?單向(正向)語言模型的Mask方式
如圖所示,Attention 矩陣的每一行事實上代表著輸出,而每一列代表著輸入,而 Attention 矩陣就表示輸出和輸入的關聯。假定白色方格都代表 0,那么第 1 行表示“北”只能跟起始標記 <s> 相關了,而第 2 行就表示“京”只能跟起始標記 <s> 和“北”相關了,依此類推。
所以,只需要在 Transformer 的 Attention 矩陣中引入下三角形形式的 Mask,并將輸入輸出錯開一位訓練,就可以實現單向語言模型了。至于 Mask 的實現方式,可以參考“讓Keras更酷一些!”:層中層與mask的 Mask 一節。
亂序語言模型
亂序語言模型是 XLNet 提出來的概念,它主要用于 XLNet 的預訓練上。說到 XLNet,我覺得它的亂序語言模型這種預訓練方式是很有意思的,但是我并不喜歡它將基本架構換成了 Transformer-XL。我覺得誰有資源可以試試“BERT+亂序語言語言模型預訓練”的組合,或許會有意外的發現。?
亂序語言模型跟語言模型一樣,都是做條件概率分解,但是亂序語言模型的分解順序是隨機的:
總之, x1, x2, … , xn 任意一種“出場順序”都有可能。原則上來說,每一種順序都對應著一個模型,所以原則上就有 n! 個語言模型。而基于 Transformer 的模型,則可以將這所有順序都做到一個模型中去!?
那怎么做到這一點呢?還是以“北京歡迎你”的生成為例,假設隨機的一種生成順序為“<s> → 迎 → 京 → 你 → 歡 → 北 → <e>”,那么我們只需要用下圖中第二個子圖的方式去 Mask 掉 Attention 矩陣,就可以達到目的了:
跟前面的單向語言模型類似,第 4 行只有一個藍色格,表示“迎”只能跟起始標記 <s> 相關,而第 2 行有兩個藍色格,表示“京”只能跟起始標記 <s> 和“迎”相關,依此類推。直觀來看,這就像是把單向語言模型的下三角形式的 Mask“打亂”了。?
也就是說,實現一種順序的語言模型,就相當于將原來的下三角形式的 Mask 以某種方式打亂。正因為 Attention 提供了這樣的一個 n × n 的 Attention 矩陣,我們才有足夠多的自由度去以不同的方式去 Mask 這個矩陣,從而實現多樣化的效果。?
說到這里,讀者可能會有一個實現上的疑問:打亂后的 Mask 似乎沒看出什么規律呀,難道每次都要隨機生成一個這樣的似乎沒有什么明顯概率的 Mask 矩陣?
事實上有一種更簡單的、數學上等效的訓練方案。這個訓練方案源于純 Attention 的模型本質上是一個無序的模型,它里邊的詞序實際上是通過 Position Embedding 加上去的。也就是說,我們輸入的不僅只有 token 本身,還包括 token 所在的位置 id;再換言之,你覺得你是輸入了序列“[北, 京, 歡, 迎, 你]”,實際上你輸入的是集合“{(北, 1), (京, 2), (歡, 3), (迎, 4), (你, 5)}”。
▲?重新排序,使得正向語言模型就可以實現亂序語言模型
既然只是一個集合,跟順序無關,那么我們完全可以換一種順序輸入,比如剛才的“<s> → 迎 → 京 → 你 → 歡 → 北 → <e>”,我們可以按“(迎, 4), (京, 2), (你, 5), (歡, 3), (北, 1)”的順序輸入,也就是說將 token 打亂為“迎,京,你,歡,北”輸入到 Transformer 中,但是第 1 個 token 的 position 就不是 1 了,而是 4;依此類推。這樣換過來之后,Mask 矩陣可以恢復為下三角矩陣,所以只需要在輸入層面打亂即可,這樣操作起來就更簡單了。
Seq2Seq
現在到我們的“重頭戲”了:將 BERT 等 Transformer 架構跟 Seq2Seq 結合起來。為什么說重頭戲呢?因為原則上來說,任何 NLP 問題都可以轉化為 Seq2Seq 來做,它是一個真正意義上的萬能模型。所以如果能夠做到 Seq2Seq,理論上就可以實現任意任務了。?
將 BERT 與 Seq2Seq 結合的比較知名的工作有兩個:MASS [6] 和 UNILM [5],兩者都是微軟的工作,兩者還都在同一個月發的。其中 MASS 還是普通的 Seq2Seq 架構,分別用 BERT 類似的 Transformer 模型來做 encoder 和 decoder,它的主要貢獻就是提供了一種 Seq2Seq 思想的預訓練方案。
真正有意思的是 UNILM,它提供了一種很優雅的方式,能夠讓我們直接用單個 BERT 模型就可以做 Seq2Seq 任務,而不用區分 encoder 和 decoder。而實現這一點幾乎不費吹灰之力——只需要一個特別的 Mask。
插曲:事實的順序是筆者前兩周自己獨立地想到了用單個 BERT 模型做 Seq2Seq 的思路,然后去找資料發現這個思路已經被做了,正是 UNILM。
UNILM 直接將 Seq2Seq 當成句子補全來做。假如輸入是“你想吃啥”,目標句子是“白切雞”,那 UNILM 將這兩個句子拼成一個:[CLS] 你 想 吃 啥 [SEP] 白 切 雞 [SEP]。經過這樣轉化之后,最簡單的方案就是訓練一個語言模型,然后輸入“[CLS] 你 想 吃 啥 [SEP]”來逐字預測“白 切 雞”,直到出現“[SEP]”為止,即如下面的左圖:
不過左圖只是最樸素的方案,它把“你想吃啥”也加入了預測范圍了(導致它這部分的 Attention 是單向的,即對應部分的 Mask 矩陣是下三角),事實上這是不必要的,屬于額外的約束。真正要預測的只是“白切雞”這部分,所以我們可以把“你想吃啥”這部分的 Mask 去掉,得到上面的右圖的 Mask。?這樣一來,輸入部分的 Attention 是雙向的,輸出部分的 Attention 是單向,滿足 Seq2Seq 的要求,而且沒有額外約束。這便是 UNILM 里邊提供的用單個 BERT 模型就可以完成 Seq2Seq 任務的思路,只要添加上述形狀的 Mask,而不需要修改模型架構,并且還可以直接沿用 BERT 的 Masked Language Model 預訓練權重,收斂更快。這符合“一 BERT 在手,天下我有”的萬用模型的初衷,個人認為這是非常優雅的方案。
▲?UNILM做Seq2Seq模型圖示。輸入部分內部可做雙向Attention,輸出部分只做單向Attention。
實驗
事實上,上述的這些 Mask 方案,基本上都已經被集成在筆者寫的 bert4keras [7],讀者可以直接用 bert4keras 加載 BERT 的預訓練權重,并且調用上述 Mask 方案來做相應的任務。下面,我們給出一個利用 UNILM 的思路做一個快速收斂的 Seq2Seq 模型的例子。?
代碼開源
這次代碼的測試任務依然是之前的標題生成,代碼調整自[8] 的原始數據集,讀者可以自行下載數據集和源碼測試復現。?
詳細請看:
https://github.com/bojone/bert4keras/blob/master/examples/task_seq2seq.py?這個效果能有多好呢?經過實驗,在標題生成的任務上,只要 7000 個 iteration,就已經能生成基本可讀的標題了。相應地,以前用 LSTM 做的時候,大概需要多 10 倍的 iteration 才有同樣的效果。
▲?只需要7000步的訓練,就可以得到基本可讀的生成結果
簡單說明
下面對代碼的關鍵部分做簡要說明。?
首先,輸入格式還是以 token_id 和 segment_id 輸入,比如:
tokens?=?['[ClS]',?u'你',?u'想',?u'吃',?u'啥',?'[SEP]',?u'白',?u'切',?u'雞',?'[SEP]'] token_ids?=?[token_dict[t]?for?t?in?tokens] segment_ids?=?[0,?0,?0,?0,?0,?0,?1,?1,?1,?1]segment_ids 用來區分輸入句子和目標句子,0 對應的為輸入句子,1 對應的為目標句子,只需要自帶的 tokenizer.encode 就可以生成這種 token_id 和 segment_id 了。
至于搭建模型,就只有寥寥幾行:
model?=?load_pretrained_model(config_path,checkpoint_path,seq2seq=True,keep_words=keep_words )model.summary()y_in?=?model.input[0][:,?1:]?#?目標tokens y_mask?=?model.input[1][:,?1:] y?=?model.output[:,?:-1]?#?預測tokens,預測與目標錯開一位#?交叉熵作為loss,并mask掉輸入部分的預測 y?=?model.output[:,?:-1]?#?預測tokens,預測與目標錯開一位 cross_entropy?=?K.sparse_categorical_crossentropy(y_in,?y) cross_entropy?=?K.sum(cross_entropy?*?y_mask)?/?K.sum(y_mask)注意 load_pretrained_model 中只要設置 seq2seq=True,就會自動加載 BERT 的 MLM 部分,并且傳入對應的 Mask,剩下就只需要把 loss 寫好就行了。另外還有一個 keep_words,這個是用來精簡 Embedding 層用的,對于中文 BERT 來說,總的 tokens 大概有 2 萬個,這意味著最后預測生成的 token 時是一個 2 萬分類問題。
但事實上這大多數 tokens 都不會被使用到,因此這 2 萬分類浪費了不少計算量。于是這里提供了一個選項,我們可以自行統計一個字表,然后傳入對應的 id,只保留這部分 token,這樣就可以降低計算量了(精簡后一般只有 5000 個左右)。?
剩下的就是通過 beam search 來解碼等步驟了,這與一般的 Seq2Seq 無異,不再贅述,大家看
總結
本文相對系統地總結了 Transformer 中 Attention 矩陣的 Mask 技巧,并且給出了用 UNILM 方案來做 Seq2Seq 的實現。對于同語言的 Seq2Seq 的文本生成任務來說,采用 UNILM 的思路加載 BERT 的 MLM 預訓練權重,能夠有效、快速地實現并提升生成效果,值得一試。
相關鏈接
[1]?https://arxiv.org/abs/1802.05365[2]?https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf[3] https://arxiv.org/abs/1810.04805[4]?https://arxiv.org/abs/1906.08237[5]?https://arxiv.org/abs/1905.03197[6]?https://arxiv.org/abs/1905.02450[7]?https://kexue.fm/archives/6915[8] http://thuctc.thunlp.org/#中文文本分類數據集THUCNews
點擊以下標題查看作者其他文章:?
基于DGCNN和概率圖的輕量級信息抽取模型
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
??來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
?
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點擊 |?閱讀原文?| 查看作者博客
總結
以上是生活随笔為你收集整理的从语言模型到Seq2Seq:Transformer如戏,全靠Mask的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 彩票店一年能赚多少?
- 下一篇: 浦发银行天添盈1号和2号的区别