One-shot Learning with Memory-Augmented Neural Networks
摘要
盡管深度學習應用領域最近取得了較大的進展,但是小樣本學習的挑戰是一直存在的,傳統的基于梯度的網絡需要大量的數據去學習,通常需要經過大量廣泛的迭代訓練。當給模型輸入新數據時,模型必須低效的重新學習其參數從而充分的融入新的信息,并不會造成較大的干擾影響。具有增強記憶能力的網絡結構,例如NTMs具有快速編碼新信息的能力,因此能消除傳統模型的缺點。這里,我們證明了記憶增強神經網絡(memory-augmented neural network)具有快速吸收新數據知識的能力,并且能利用這些吸收了的數據,在少量樣本的基礎上做出準確的預測。
我們也介紹了一個訪問外部記憶存儲器的方法,該方法關注于記憶存儲器的內容,這和之前提出的使用基于記憶存儲器位置的聚焦機制的方法不同。
當前深度學習的成功取決于基于梯度的優化算法應用于高容量模型(神經元數量)的能力。這種方法在許多以原始感官為輸入的大型監督任務上已經取得了非常好的結果,例如圖像分類、語音識別、游戲等。值得注意的是,這些任務上的表現通常是在大型數據集上經過廣泛的增量式訓練來評估得出的。相反,許多興趣問題(many problems of interest)需要從少量的數據中快速推斷出結果。在one-shot learning的記心中,單一的觀察結果會導致行為的突然轉變。
這種靈活的適應是人類學習中一個重要的方面,從發動機的控制到抽象概念的獲取都得到了表現。根據少量信息的推斷生成新的行為,比如推斷一個只在上下文中出現過一兩次的單詞的全局適用性,這是超出當代智能能力的。這對深度學習提出了嚴峻的挑戰,只有在少數樣本逐一呈現的情況下才有一種簡單的基于梯度的解決方案:從目前可用數據中完全重新學習參數。但是這種方法往往會導致不良學習和災難性干擾,這時非參數方法往往被認為更合適。
可是先前的工作提出一種從稀疏數據中學習的策略,并取決于元學習的概念。雖然meta-learning。雖然meta-learning術語已經被用在很多領域。元學習通常考慮為學習兩種水平的場景,并且每個水平和不同的時間尺度有關。快速學習通常出現在一個任務內,例如在特定的數據集中進行準確分類。這種學習是由在任務中逐漸積累的知識來指導的,這些知識捕獲了任務結構在目標域中的變化方式或變化規律。考慮到這種結構的兩層形式,因此也被叫做learning to learn。
已經提出的具有記憶能力的神經網絡能夠證明確實能夠進行元學習。這些網絡能夠通過權重更新改變偏置的值,并能通過快速學習記憶存儲中的緩存表示(cache representations in memory stores)來調整輸出結果。例如用lstms當做元學習的網絡能根據少量的數據樣本就能快速學習到之前沒有見過的二次函數。
具有記憶能力的神經網絡給元學習在深度網絡中提供了一種可行的方法。但是使用非結構化循環網絡結構的內在記憶單元這種特定的策略不可能擴展到每個新任務都需要快速編碼吸收大量新信息的場景。一個可擴展的解決方案必須有以下必要的要求:首先,信息必須穩定的表現形式存儲在記憶存儲器中 (以便在需要時可以可靠地訪問),并且記憶中的元素可尋址(以便可以選擇性的訪問信息);其次,參數的數量不應該和記憶存儲器的大小有關聯。標準的具有記憶的結構例如LSTMs并沒有這兩種特性。然而最近的架構中如神經圖靈機NTMS和記憶網絡滿足了這兩個特點的要求。因此在文中我們從一個高容量的記憶增強神經網絡的角度重新考慮了元學習的問題和設置(setting),(注:這里MANN指配備外部記憶的網絡,而不是其他內部記憶單元的架構如LSTM)。
我們的方法結合最有利的兩部分:通過梯度下降慢慢的從原數據中獲取有用表示(representations)的抽象方法;通過外部記憶存儲模塊在一次表示之后(after a single presentation)快速吸收沒有見到過的知識。這種結合使元學習更加健壯,并擴展了可以有效應用深度學習的問題范圍。
通常我們選擇一個參數θ在數據集D上去最小化損失函數L。可是對于元學習來說,我們選擇參數來降低數據集分布P(D)中的期望損失。
要做到這一點,正確的任務設置至關重要。在我們的任務設置中,一個任務或者插曲片段(a task, or episode)涉及一些數據集D的表示。Yt既是一個目標,也是以時間偏移的方式與xt一塊作為輸入。這個網絡的目的是在給定的時間戳t上為xt輸出正確的標簽。重要的是,標簽是從數據集中混洗得到的,這樣能夠防止網絡緩慢的學習樣本和類的綁定關系來更新權重。相反的的是,網絡必須將數據樣本存在內存中,直到下一個時間戳到達,正確的類標簽被展示出來,在這之后,樣本和類標簽的對應關系能被發現并且存儲這種關系信息供以后使用。因此,對于給定的一段情節(episode),理想的表現會涉及到對第一個類的標簽值(the first presentation of a class,我理解為類的值)的隨機猜測,(因為標簽被混洗了,不能根據之前的情節推斷出正確的標簽),并且之后使用記憶存儲器來實現準確率的完美預測。最終,這個系統目標是對預測分布p進行建模,在每一個時間步引起相應的損失。
這個任務結構包含可利用的元知識:元學習的模型學習將數據表示綁定到其對應的正確標簽,而不管數據表示或標簽的實際內容如何,并且將采用一般方案將這些綁定表示(bound representations)映射到正確的類或用于預測的函數值。
3.1神經圖靈機
神經圖靈機是MANN一種完全不同的實現。他包括一個控制器,例如一個前饋網絡或者LSTM,這和一個使用一些讀寫頭的額外記憶模塊相互影響。圖靈機中的記憶模塊的記憶單元編碼和索引都是很快的,向量表示可能在每個時間步驟被放入或取出內存。這種能力使NTM稱為元學習和短時預測完美的候選者,因為它既能通過慢的權重更新實現長期存儲,并且通過額外記憶模塊實現短期存儲。如果NTM能夠學習一種通用策略來將各種表示(representations,這里指內存單元中記錄的信息)類型放入記憶單元中,并且能夠學習之后如何使用這些表示來做預測,那么他可能利用他的速度來對僅見過一次的數據進行準確預測。
?????? 我們模型中的控制器要么使用LSTMs或者前饋網絡。控制器使用讀寫頭與外部存儲器模塊交互,讀寫頭分別用于從存儲器中檢索表示(representations)或將它們放入存儲器中。給定一些輸入xi,控制器生成一個鍵值kt,這個鍵值被存入記憶矩陣Mt的一行,或者被用于從一行中索引一個特定的記憶單元i,Mt(i),當索引一個記憶單元Mt的時候,會使用余弦相似度。
用于去產生讀權重向量Wrt,根據以下公式計算得到
一個記憶單元rt,通過使用權重向量進行索引:
這個記憶單元的內容被控制器作為一個分類器的輸入淚如softmax層的輸入,或者作為下一個控制器狀態的額外輸入。
3.2、最少或最近使用的記憶信息
??? 在之前NTM的例子中,記憶信息通過內容或者位置被索引。基于位置的索引常常被用于迭代更新的步驟,就像沿著磁帶跑一樣,也回用于在記憶信息上的長距離跳躍。這種方法對于基于序列預測的任務是有優勢的,可是這種方式對于強調獨立于序列之外的信息的任務并不是最優的。因此,在我們的模型中,包含一個新設計的讀取記憶信息的模式叫做LRUA。
??? LRUA模型是一個純粹的基于內容的記憶讀寫方式,記憶信息要么被寫到斤少使用的記憶模塊的位置或者最近使用的記憶模塊的位置。這個模塊看重有關信息的準確編碼(吸收提取數據的知識),并且是完全的基于內容的索引。新的信息被寫入到很少使用的位置或者寫入到最后使用的位置,保存最近編碼的信息,這可以用更加新的、可能更相關的信息更新的記憶信息。這兩種方式的不同在于先前的讀參數和使用參數(usage weights wtu),這些使用參數通過衰減參數逐步更新參數值,
這里,gama是衰減參數,讀向量參數由前邊計算出來,最少使用的權重能通過用戶參數計算出來,其中m(v,n)表示前n個
n是讀記憶的數目,寫參數向量(write weights wtw)由以下方式計算得到:
σ(·) 是sigmoid函數,
記憶信息能夠被寫到標記為零記憶槽,或者之前被使用過的槽(slot),如果是之前使用過的槽,那么就是最少被使用的槽,并且原本槽里的記憶信息會被刪除。
總結
以上是生活随笔為你收集整理的One-shot Learning with Memory-Augmented Neural Networks的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《调笑令·周年》
- 下一篇: 漫话:如何给女朋友解释为什么一到年底,部