谷歌Deep Bootstrap Framework:在线优化角度理解神经网络
The Deep Bootstrap Framework: Good Online Learners are Good Offline Generalizers(ICLR21)
一元@煉丹筆記理解深度學習的泛化性是一個尚未解決的基本問題。為什么在有限的訓練數據集上優化模型能在一個hold-out的測試集中取得良好的性能?這一問題在機器學習中已經被研究了將近50多年。現在存在非常多數學工具可以幫助研究人員理解某些模型中的泛化性能。但是很不幸的是,大多數現有理論在應用于現代深網絡時都失敗了——它們在現實環境中既空洞又不具有預測性。
對于過長度參數化的模型,理論與實踐之間的差距是巨大的,對于理論上有能力過擬合其訓練集的模型,但在實踐中卻往往沒有。
我們提出了一個新的框架,通過將泛化性與在線優化領域相結合來解決這個問題。在一個典型的設置中,一個模型在一組有限的樣本上訓練,這些樣本被多個epoch重復使用。但在在線優化中,模型可以訪問無限的樣本流,并且可以在處理該樣本流時進行迭代更新。
我們發現在無限數據上快速訓練的模型與在有限數據上訓練的模型具有相同的泛化能力。這種聯系為實踐中的設計選擇帶來了新的視角,并為從理論角度理解泛化奠定了路線圖。
Deep-Bootstrap框架的核心思想是將存在有限訓練數據的現實世界與存在無限數據的“理想世界”進行比較。我們將其定義為:
- Real World(N,T): 在某個分布中的N個訓練樣本上訓練模型,對于T個minibatch隨機梯度下降(SGD)步,在多個epoch上重復使用相同的N個樣本。這相當于在經驗損失(訓練數據損失)上運行SGD,屬于監督學習中的標準訓練過程。
- Ideal World(T): 在T步中訓練相同的模型,但是在每個SGD步中使用來自分布的全新樣本。也就是說,我們運行完全相同的訓練代碼(相同的優化器、學習速率、batch-size大小等),但在每個epoch中隨機采樣一個新的訓練集,而不是重用樣本。在這個理想的世界環境中,有一個有效的無限“訓練集”,訓練誤差和測試誤差之間沒有區別。
在先驗上,人們可能會認為現實世界和理想世界彼此無關,因為在現實世界中,模型看到的是有限數量的分布示例,而在理想世界中,模型看到的是整個分布。但在實際應用中,我們發現真實模型和理想模型實際上存在著相似的檢驗誤差。
為了量化這一觀察結果,我們通過創建一個新的數據集(我們稱之為CIFAR-5m)來模擬一個理想的世界環境。我們在CIFAR-10上訓練了一個生成模型,然后用它生成了約600萬張圖像。選擇數據集的規模是為了確保從模型的角度來看它“實際上是無限的”,這樣模型就不會對相同的數據進行重采樣。也就是說,在理想世界中,模型看到的是一組全新的樣本。
下圖顯示了幾種模型的測試誤差,比較了它們在真實環境(即重復使用的數據)和理想環境(“新”數據)中接受CIFAR-5m數據訓練時的性能。藍色實線顯示了現實世界中的ResNet模型,該模型使用標準CIFAR-10超參數在50K樣本上訓練100個epoch。藍色虛線顯示了理想世界中的相應模型,在一次過程中對500萬個樣本進行了訓練。令人驚訝的是,這些世界有著非常相似的測試錯誤——在某種意義上,模型“不在乎”它看到的是重復使用的樣本還是新的樣本。
這也適用于其它的架構,例如MLP、Vision Transformer,以及架構、優化器、數據分布和樣本大小的許多其他設置。這些實驗為泛化提供了一個新的視角:快速優化(在無限數據上)和良好的泛化(在有限數據上)模型。例如,ResNet模型比MLP模型在有限數據上的泛化效果更好,但這“是因為”即使在無限數據上,它的優化速度也更快。
我們核心的觀察結果是,真實世界和理想世界的模型在測試誤差中始終保持接近,直到真實世界收斂(<1%的訓練誤差)。因此,人們可以通過研究模型在理想世界中的相應行為來研究現實世界中的模型。
這也意味著模型的泛化可以從兩個框架下的優化性能來理解:
- 在線優化:理想世界測試誤差減少的速度有多快;
- 離線優化:真實世界的訓練誤差收斂速度有多快;
因此,為了研究泛化,我們可以等價地研究上述兩個術語,這在概念上可能更簡單,因為它們只涉及優化問題。基于這一觀察,好的模型和訓練過程是:
- 在理想世界中快速優化;
- 在現實世界中不會太快地優化模型;
深度學習中的所有設計選擇都可以通過它們對這兩個terms的影響來看待。例如,一些進展,如卷積,skpi連接和預訓練主要通過加速理想世界的優化來進行幫助,而其它的進步,如正則化和數據增強,則主要通過減速現實世界的優化來幫助。
研究人員可以使用Deep-Bootstrap框架來研究和指導深度學習中的設計選擇。其原理是:當一個人在現實世界中做出影響泛化的改變(結構、學習率等),他應該考慮它對:
- 測試誤差的理想世界優化(越快越好);
- 訓練誤差的現實世界優化(越慢越好)。
例如,在實踐中經常使用預訓練來幫助在小數據區域中的模型泛化。然而,人們對預訓練能帶來幫助的原因仍知之甚少。
我們可以使用Deep Bootstrap框架來研究這一點,方法是觀察上述(1)和(2)項的預訓練效果。我們發現預訓練的主要效果是改善理想世界的優化,
- 預訓練使網絡成為在線優化的“快速學習者”。
因此,在理想世界中,預訓練模型的改進泛化幾乎被其改進的優化所準確捕獲。下圖顯示了在CIFAR-10上訓練的Vision-Transformers (ViT)的情況,比較了從頭開始的訓練和在ImageNet上的預訓練。
我們也可以使用這個框架來研究數據擴充。理想世界中的數據擴充對應于將每個新樣本都會擴充一次,而不是將同一樣本擴充多次。這個框架意味著良好的數據擴充是指
- 不會顯著損害理想世界優化(即,擴充樣本看起來不會太“偏離分布”);
- 抑制真實世界優化速度(因此真實世界需要更長時間來適應其訓練集合)。
數據擴充的主要好處是通過第二項,延長了實際優化時間。至于第一項,一些激進的數據擴充(混合/剪切)實際上會損害理想世界,但這種效果與第二項相形見絀。
Deep-Bootstrap框架為深度學習中的泛化現象和經驗現象提供了一個新的視角。希望它可以應用到理解未來深度學習的其它方面。特別有趣的是,泛化可以通過純粹的優化考慮來表征,這與理論上許多流行的方法形成了對比。最關鍵的是,我們可以同時考慮在線和離線優化,它們都不是非常充分,但它們共同決定了泛化。
Deep-Bootstrap框架還可以解釋為什么Deep-learning對許多設計選擇相當魯棒:許多類型的體系結構、損失函數、優化器、規范化和激活函數都可以很好地泛化。這個框架提出了一個統一的原則:從本質上講,任何在在線優化環境下運行良好的選擇,也會在離線環境下得到很好的泛化。
最后,現代神經網絡既可以參數化過度(例如,針對小數據任務訓練的大型網絡),也可能參數化不足(例如,OpenAI的GPT-3、Google的T5或Facebook的ResNeXt WSL等等)。Deep-Bootstrap框架則表明在線優化是兩種模式成功的關鍵因素。
總結
以上是生活随笔為你收集整理的谷歌Deep Bootstrap Framework:在线优化角度理解神经网络的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 最强的Attention函数诞生啦,带给
- 下一篇: 一文弄懂各种loss function
