Transformer太大了,我要把它微调成RNN
文 | 煉丹學(xué)徒
編 | 小軼
從前車馬很慢,顯卡跑的也慢,一生只夠愛一個RNN。后來時代進步了,數(shù)據(jù)量和計算力闊綽了,堆疊起來的Transformer能夠在更深更寬的模型結(jié)構(gòu)里吃下去更多的數(shù)據(jù)。從19年的預(yù)訓(xùn)練浪潮開始,暴力美學(xué)興起,更深的Transformer更久的預(yù)訓(xùn)練更大的模型參數(shù)量,暴力出奇跡一個個NLP榜單被刷新,但誰又記得起來當(dāng)初Transformer論文里“解決RNN無法并行化訓(xùn)練問題”的追求效率的motivation呢?身在普通高校,手握2080Ti和Titan V,向著大廠的預(yù)訓(xùn)練模型望洋興嘆,我們開始懷念起當(dāng)初人人都訓(xùn)練得起的LSTM和GRU。那是精巧輕量的模型,那是人人都刷的起SOTA的時代。
今天這篇來自微軟的論文告訴我們,大廠里有一些研究員也還是愛我們的,Finetuning Pretrained Transformers into RNNs,在保持性能的情況下,將預(yù)訓(xùn)練好的Transformer模型微調(diào)到其RNN變體,極大地降低顯存使用和計算開銷。
論文題目:
Finetuning Pretrained Transformers into RNNs
論文鏈接:
https://arxiv.org/abs/2103.13076
Arxiv訪問慢的小伙伴也可以在 【夕小瑤的賣萌屋】訂閱號后臺回復(fù)關(guān)鍵詞 【0407】 下載論文PDF~
本文提出的模型名為 T2R,代表 Transformer to RNN 。轉(zhuǎn)換的過程為 swap-then-finetune ,即,對于一個預(yù)訓(xùn)練好的 Transformer 模型,我們將其 的注意力計算改為線性 的替換模塊,然后進行微調(diào)。可以預(yù)感到,其核心就在于如何用線性的子層對注意力層進行模擬。接下來,我們對其進行詳解。
概述
在2019年EMNLP論文 Transformer Disp [1] 中,作者提出:可以將注意力層的相似度計算()替換為核函數(shù)的分數(shù)。
ICML'20的另一工作Transformers are RNNs [2]則在此基礎(chǔ)上進一步優(yōu)化,提出了將的注意力計算替換為線性的模塊。
今天要講的 T2R 這篇文章是緊隨上面 ICML'20 這篇工作進行的。之前 Transformers are RNNs 的方法中,使用的核函數(shù)沒有參數(shù),不可訓(xùn)。而 T2R 把核函數(shù)里封裝了一個MLP變成可訓(xùn)練的。T2R原文的推導(dǎo)直接使用了 Transformers are RNNs 與 Transformer Disp 的結(jié)論,因而推導(dǎo)過程并不完整。我們今天也沿著T2R的思路進行講解,如果想要更深入了解 Transformer 轉(zhuǎn) RNN 領(lǐng)域的,可以閱讀下面兩篇論文:
[1] Tsai et al. Transformer Disp: A Unified Understanding of Transformer's Attention via the Lens of Kernel. EMNLP 2019
[2] Katharopoulos et al. Transformers are RNNs: Fast autoregressive transformers with linear attention. ICML 2020
Transformer開銷
Transformer 由多頭注意力層、前饋層、層歸一化層堆疊后組成。本篇論文中要替換的,就是其中的多頭注意力層。
在開始講解如何替換之前,我們還是先梳理一下傳統(tǒng)Transformer的多頭注意力層。整個計算過程可以總結(jié)如下圖所示:
▲傳統(tǒng)Transformer的多頭注意力層計算過程這張圖我們自下往上看。首先,我們將多頭注意力層的source隱狀態(tài)記作,target隱狀態(tài)記作。
如何理解此處的source和target:比如,在解碼器的編碼器-解碼器注意力層中,就是編碼器端的序列長度,就是解碼器端的長度。在自回歸推斷的解碼器自注意力層中,就是已生成序列(加上自己)的長度,等于1,指當(dāng)前要預(yù)測的這個字符。
從隱狀態(tài),我們通過線性變換得到。則,注意力層的輸出為:
其中, 操作 旨在計算和的相似度(這里劃重點!等一會兒就要對這個計算動手腳了!):
上述的多頭注意力的計算是我們熟知的。論文對其復(fù)雜度進行了分析。設(shè)多頭數(shù)為,每個頭的隱狀態(tài)長度,每個的隱狀態(tài)總長 ,則有如下結(jié)論:
特征計算:即由隱狀態(tài)計算得到的過程,復(fù)雜度分別為 , 和
注意力計算: 由 計算得到最終輸出的過程,復(fù)雜度為 ,與 的長度成平方關(guān)系。
推斷時的顯存:,與已經(jīng)解碼的長度線性相關(guān)。
注意力層的RNN替代方案
T2R的注意力層計算過程則如下圖所示:
首先,我們注意到原始的注意力計算中, 和 的相似度計算方式()需要先進行點乘,放縮后再進行指數(shù)運算,難以開展后續(xù)的近似優(yōu)化。所以這里的關(guān)鍵之處就在于,T2R把的相似度計算方案替換為核函數(shù)的乘積:
此處,和的參數(shù)都是通過一個單層MLP學(xué)習(xí)得到的。 是維矩陣,是維bias向量,即,T2R的相似度計算核函數(shù)將原本維的向量降到了維然后進行相似度計算。對于多頭計算中的每一個頭,他們的和是獨立學(xué)出來的。因此,T2R在每一層中,共增加了個可學(xué)習(xí)的參數(shù)(小于總參數(shù)量的2%)。
我們把新的相似度計算方法代入到注意力的輸出式中,得到:
記,,則:
而根據(jù) Transformers are RNNs [2] 的結(jié)論,此處的可以視作RNN遞歸的隱狀態(tài)。比如,在解碼器端做自回歸生成時,每個詞向它前文的單詞進行注意力計算來預(yù)測下一個詞,和可以被定義為遞歸的隱狀態(tài):
注意到我們主要討論的函數(shù)是針對來計算相似度的,而是由喂入該層的隱狀態(tài)線性變化得到的。為了加速推斷速度,具體實現(xiàn)中把和代入,得到從隱狀態(tài),直接線性變換得到的結(jié)果,從而在推斷的時候不需要計算,而從隱狀態(tài)直接計算得到相似度的值,即:
其中,
此時的開銷:
特征計算:我們記輸出維的特征向量,則生成的復(fù)雜度為 , 和
注意力計算: 由計算得到最終輸出的過程,假設(shè)k<<M,N,此時復(fù)雜度為,與的長度成線性關(guān)系。
推斷時的顯存:假設(shè)k<<M,則占用顯存,為常數(shù)。
Transformer和T2R對比
講到這里,我們再對比一下傳統(tǒng)Transformer和T2R的差異:
特征計算:計算不變,計算由, 降為,
注意力計算: 由降為,平方->線性。
推斷時的顯存:由降為,線性->常數(shù)。
實驗
數(shù)據(jù)集的效果
T2R主要使用ELU和RFA作為baseline進行比較。ELU和RFA為此前的另外兩篇使用核函數(shù)轉(zhuǎn)Transformer為RNN工作。因為ELU和RFA的核函數(shù)都是不可訓(xùn)練的,所以無法取代預(yù)訓(xùn)練好的模型里的注意力層進行功能上的替換和擬合。
首先,T2R在語言模型上開展了實驗。數(shù)據(jù)集使用WikiText-103,評測指標使用困惑度 perplexity 。發(fā)現(xiàn)T2R因為在核函數(shù)中放置了可訓(xùn)練的MLP,在加載預(yù)訓(xùn)練模型時獲得更大的收益。
此外,T2R在翻譯任務(wù)上開展實驗,使用數(shù)據(jù)集 WMT14 EN-DE,WMT14 EN-FR 和 WMT17 ZH-EN。研究員們發(fā)現(xiàn)雖然隨機初始化時,T2R弱于另外兩個baseline,但是加載預(yù)訓(xùn)練后反超另外兩個baseline。
生成時的加速和顯存節(jié)省
研究員發(fā)現(xiàn) T2R 比另外兩個模型的推斷速度更快(如下左圖所示),因為使用了更小的特征維度,以及更快的特征計算方法。對于推斷時的顯存占用,Transformer 隨著輸出序列的增長而線性增加,轉(zhuǎn)為 RNN 結(jié)構(gòu)的模型則保持常數(shù)(如下右圖所示)。
消融實驗
隨著核函數(shù)輸出特征尺寸的增大,其效果也更加接近Transformer。相比于之前的工作,T2R 可以通過控制特征尺寸從而在效果和速度間權(quán)衡。
小結(jié)
本文提出的T2R,在 Transformers are RNNs 的基礎(chǔ)上,將無參數(shù)的核函數(shù)封裝為 MLP 加激活函數(shù),從而可訓(xùn)練。在此基礎(chǔ)上,T2R 替換掉預(yù)訓(xùn)練 Transformer 的注意力層,從而降低了計算消耗和顯存使用,并且得到和原預(yù)訓(xùn)練模型相似的結(jié)果。
后臺回復(fù)關(guān)鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
后臺回復(fù)關(guān)鍵詞【頂會】
獲取ACL、CIKM等各大頂會論文集!
總結(jié)
以上是生活随笔為你收集整理的Transformer太大了,我要把它微调成RNN的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 12种NumpyPandas高效技巧
- 下一篇: 95后CV工程师晒出工资单:狠补了这个,