何恺明团队最新力作SimSiam:消除表征学习“崩溃解”,探寻对比表达学习成功之根源
該文是FAIR的陳鑫磊&何愷明大神在無監(jiān)督學(xué)習(xí)領(lǐng)域又一力作,提出了一種非常簡單的表達(dá)學(xué)習(xí)機(jī)制用于避免表達(dá)學(xué)習(xí)中的“崩潰”問題,從理論與實(shí)驗(yàn)角度證實(shí)了所提方法的有效性;與此同時(shí),還側(cè)面證實(shí)了對(duì)比學(xué)習(xí)方法成功的關(guān)鍵性因素:孿生網(wǎng)絡(luò)。
paper: https://arxiv.org/abs/2011.10566
本文為極市平臺(tái)原創(chuàng),作者Happy,轉(zhuǎn)載需獲授權(quán)。
Abstract
孿生網(wǎng)絡(luò)已成為無監(jiān)督表達(dá)學(xué)習(xí)領(lǐng)域的通用架構(gòu),現(xiàn)有方法通過最大化同一圖像的兩者增廣的相似性使其避免“崩潰解(collapsing solutions)”問題。在這篇研究中,作者提出一種驚人的實(shí)證結(jié)果:**Simple Siamese(SimSiam)**網(wǎng)絡(luò)甚至可以在無((1) negative sample pairs;(2)large batch;(3)momentum encoders)的情形下學(xué)習(xí)有意義的特征表達(dá)。
作者通過實(shí)驗(yàn)表明:對(duì)于損失與結(jié)構(gòu)而言,“崩潰解”確實(shí)存在,但是“stop-gradient”操作對(duì)于避免“崩潰解”有非常重要的作用。作者提出了一種新穎的“stop-gradient”思想并通過實(shí)驗(yàn)對(duì)其進(jìn)行了驗(yàn)證,該文所提SimSiam在ImageNet及下游任務(wù)上均取得了有競爭力的結(jié)果。作者期望:這個(gè)簡單的基準(zhǔn)方案可以驅(qū)動(dòng)更多研員重新思考無監(jiān)督表達(dá)學(xué)習(xí)中的孿生結(jié)構(gòu)。
Method
上圖給出了該文所提SimSiam的示意圖,它以圖像xxx的兩個(gè)隨機(jī)變換x1,x2x_1, x_2x1?,x2?作為輸入,通過相同的編碼網(wǎng)絡(luò)fff(它包含一個(gè)骨干網(wǎng)絡(luò)和一個(gè)投影MLP頭模塊,表示為h)提取特征并變換到高維空間。此外作者還定義了一個(gè)預(yù)測MLP頭模塊h,對(duì)其中一個(gè)分支的結(jié)果進(jìn)行變換并與另一個(gè)分支的結(jié)果進(jìn)行匹配,該過程可以描述為p1=h(f(x1)),z2=f(x2)p_1 = h(f(x_1)), z_2 = f(x_2)p1?=h(f(x1?)),z2?=f(x2?),SimSiam對(duì)上述特征進(jìn)行負(fù)cosine相似性最小化:
D(p1,z2)=?p1∥p1∥2?z2∥z2∥2\mathcal{D}(p_1, z_2) = - \frac{p_1}{\|p_1\|_2} \cdot \frac{z_2}{\|z_2\|_2} D(p1?,z2?)=?∥p1?∥2?p1???∥z2?∥2?z2??
注:上述公式等價(jià)于l2l_2l2?規(guī)范化向量的MSE損失。與此同時(shí),作者還定義了一個(gè)對(duì)稱損失:
L=12D(p1,z2)+12D(p2,z1)\mathcal{L} = \frac{1}{2}\mathcal{D}(p_1, z_2) + \frac{1}{2}\mathcal{D}(p_2, z_1) L=21?D(p1?,z2?)+21?D(p2?,z1?)
上述兩個(gè)損失作用于每一張圖像,總損失是所有圖像損失的平均,故最小的可能損失為-1.
需要的是:該文一個(gè)非常重要的概念是Stop-gradient操作(即上圖的右分支部分)。可以通過對(duì)上述公式進(jìn)行簡單的修改得到本文的損失函數(shù):
D(p1,stopgrad(zx))L=12D(p1,stopgrad(z2))+12D(p2,stopgrad(z1))\mathcal{D}(p_1, stopgrad(z_x)) \\ \mathcal{L} = \frac{1}{2}\mathcal{D}(p_1, stopgrad(z_2)) + \frac{1}{2}\mathcal{D}(p_2, stopgrad(z_1)) D(p1?,stopgrad(zx?))L=21?D(p1?,stopgrad(z2?))+21?D(p2?,stopgrad(z1?))
也就是說:在損失L\mathcal{L}L的第一項(xiàng),x2x_2x2?不會(huì)從z2z_2z2?接收梯度信息;在其第二項(xiàng),則會(huì)從p2p_2p2?接收梯度信息。
SimSiam的實(shí)現(xiàn)偽代碼如下,有沒有一種“就這么簡單”的感覺???
# Algorithm1 SimSiam Pseudocode, Pytorch-like # f: backbone + projection mlp # h: prediction mlp for x in loader: # load a minibatch x with n samplesx1, x2 = aug(x), aug(x) # random augmentationz1, z2 = f(x1), f(x2) # projections, n-by-dp1, p2 = h(z1), h(z2) # predictions, n-by-dL = D(p1, z2)/2 + D(p2, z1)/2 # lossL.backward() # back-propagateupdate(f, h) # SGD updatedef D(p, z): # negative cosine similarityz = z.detach() # stop gradientp = normalize(p, dim=1) # l2-normalizez = normalize(z, dim=1) # l2-normalize return -(p*z).sum(dim=1).mean()我們?cè)賮砜匆幌耂imSiam的基礎(chǔ)配置:
- Optimizer: SGD用于預(yù)訓(xùn)練,學(xué)習(xí)率為lr×BatchSize/256lr \times BatchSize/256lr×BatchSize/256, 基礎(chǔ)學(xué)習(xí)率為lr=0.05lr=0.05lr=0.05,學(xué)習(xí)率采用consine衰減機(jī)制,weight decay=0.0001,momentum=0.9。BatchSize默認(rèn)512,采用了SynBatchNorm。
- Projection MLP:編碼網(wǎng)絡(luò)中投影MLP部分的每個(gè)全連接層后接BN層,其輸出層fcfcfc后無ReLU,隱含層的fcfcfc的維度為2048,MLP包含三個(gè)全連接層。
- Prediction MLP:預(yù)測MLP中同樣適用了BN層,但其輸出層fcfcfc后無BN與ReLU。MLP有2個(gè)全連接層,第一個(gè)全連接層的輸入與輸出維度為2048,第二個(gè)的輸出維度為512.
- Backbone:作者選用了ResNet50作為骨干網(wǎng)絡(luò)。
作者在ImageNet上線進(jìn)行無監(jiān)督預(yù)訓(xùn)練,然后采用監(jiān)督方式凍結(jié)骨干網(wǎng)絡(luò)訓(xùn)練分類頭,最后在驗(yàn)證集上驗(yàn)證其性能。
Empirical Study
在該部分內(nèi)容中,我們將實(shí)證研究SimSiam的表現(xiàn),主要聚焦于哪些行為有助于避免“崩潰解”。
Stop-gradient
上圖給出了Stop-gradient添加與否的性能對(duì)比,注網(wǎng)絡(luò)架構(gòu)與超參保持不變,區(qū)別僅在于是否添加Stop-gradient。
上圖left表示訓(xùn)練損失,可以看到:在無Stop-gradient時(shí),優(yōu)化器迅速找了了一個(gè)退化解并達(dá)到了最小可能損失-1。為證實(shí)上述退化解是“崩潰”導(dǎo)致的,作者研究了輸出的l2l_2l2?規(guī)范化結(jié)果的標(biāo)準(zhǔn)差。如果輸出“崩潰”到了常數(shù)向量,那么其每個(gè)通道的標(biāo)準(zhǔn)差應(yīng)當(dāng)是0,見上圖middle。
作為對(duì)比,如果輸出具有零均值各項(xiàng)同性高斯分布,可以看到其標(biāo)準(zhǔn)差為1d\frac{1}{\sqrtze8trgl8bvbq}d?1?。上圖middle中的藍(lán)色曲線(即添加了Stop-gradient)接近1d\frac{1}{\sqrtze8trgl8bvbq}d?1?,這也就意味著輸出并沒有“崩潰”。
上圖right給出了KNN分類器的驗(yàn)證精度,KNN分類器可用于訓(xùn)練過程的監(jiān)控。在無Stop-gradient時(shí),其分類進(jìn)度僅有0.1%,而添加Stop-gradient后最終分類精度可達(dá)67.7%。
上述實(shí)驗(yàn)表明:“崩潰”確實(shí)存在。但“崩潰”的存在不足以說明所提方法可以避免“崩潰”,盡管上述對(duì)比中僅有“stop-gradient”的區(qū)別。
Predictor
上表給出了Predictor MLP的影響性分析,可以看到:
-
當(dāng)移除預(yù)測MLP頭模塊h(即h為恒等映射)后,該模型不再有效(work);
-
如果預(yù)測MLP頭模塊h固定為隨機(jī)初始化,該模型同樣不再有效;
-
當(dāng)預(yù)測MLP頭模塊采用常數(shù)學(xué)習(xí)率時(shí),該模型甚至可以取得比基準(zhǔn)更好的結(jié)果(多個(gè)實(shí)驗(yàn)中均有類似發(fā)現(xiàn)).
Batch Size
上表給出了Batch Size從64變換到4096過程中的精度變化,可以看到:該方法在非常大范圍的batch size下表現(xiàn)均非常好。
Batch Normalization
上表比較了投影與預(yù)測MLP中不同BN的配置對(duì)比,可以看到:
- 移除所有BN層后,盡管精度只有34.6%,但不會(huì)造成“崩潰”;這種低精度更像是優(yōu)化難問題,對(duì)隱含層添加BN后精度則提升到了67.4%;
- 在投影MLP的輸出后添加BN,精度可以進(jìn)一步提升到68.1%;
- 在預(yù)測MLP的輸出添加BN后反而導(dǎo)致訓(xùn)練變的不穩(wěn)定。
總而言之,BN有助于訓(xùn)練優(yōu)化,這與監(jiān)督學(xué)習(xí)中BN的作用類似;但并未看到BN有助于避免“崩潰”的證據(jù)。
Similarity Function
所提方法除了與cosine相似性組合表現(xiàn)好外,其與交叉熵相似組合表現(xiàn)同樣良好,見上表。此時(shí)的交叉熵相似定義如下:
D=?softmax(zx)?logsoftmax(p1)\mathcal{D} = -softmax(z_x) \cdot \text{log} softmax(p_1) D=?softmax(zx?)?logsoftmax(p1?)
可以看到:交叉熵相似性同樣可以收斂到一個(gè)合理的解并不會(huì)導(dǎo)致“崩潰”,這也就是意味著“崩潰”避免行為與cosine相似性無關(guān)。
Symmetrization
盡管前述描述中用到了對(duì)稱損失,但上表的結(jié)果表明:SimSiam的行為不依賴于對(duì)稱損失:非對(duì)稱損失同樣取得了合理的結(jié)果,而對(duì)稱損失有助于提升精度,這與“崩潰”避免無關(guān)。
Summary
通過上面的一些列消融實(shí)驗(yàn)對(duì)比分析,可以看到:SimSiam可以得到有意義的結(jié)果而不會(huì)導(dǎo)致“崩潰”。優(yōu)化器、BN、相似性函數(shù)、對(duì)稱損失可能會(huì)影響精度,但與“崩潰”避免無關(guān);對(duì)于“崩潰”避免起關(guān)鍵作用的是stop-gradient操作。
Hypothesis
接下來,我們將討論:SimSiam到底在隱式的優(yōu)化什么?并通過實(shí)驗(yàn)對(duì)其進(jìn)行驗(yàn)證。主要從定義、證明以及討論三個(gè)方面進(jìn)行介紹。
Formulation
作者假設(shè):SimSiam是類期望最大化算法的一種實(shí)現(xiàn)。它隱含的包含兩組變量,并解決兩個(gè)潛在子問題,而stop-gradient操作是引入額外變換的結(jié)果。我們考慮如下形式的損失:
L(θ,η)=Ex,τ[∥Fθ(τ(x))?ηx∥22]\mathcal{L}(\theta, \eta) = E_{x, \tau}[\|\mathcal{F}_{\theta}(\tau(x)) - \eta_x\|_2^2] L(θ,η)=Ex,τ?[∥Fθ?(τ(x))?ηx?∥22?]
其中F,τ\mathcal{F}, \tauF,τ分別表示特征提取網(wǎng)絡(luò)與數(shù)據(jù)增廣方法,x表示圖像。在這里,作者引入了另外一個(gè)變量η\etaη,其大小正比于圖像數(shù)量,直觀上來講,ηx\eta_xηx?是x的特征表達(dá)。
基于上述表述,我們考慮如下優(yōu)化問題:
minθ,ηL(θ,η)min_{\theta, \eta} \mathcal{L}(\theta, \eta) minθ,η?L(θ,η)
這種描述形式類似于k-means聚類問題,變量θ\thetaθ與聚類中心類似,是一個(gè)可學(xué)習(xí)參數(shù);變量ηx\eta_xηx?與樣本x的對(duì)應(yīng)向量(類似k-means的one-hot向量)類似:即它是x的特征表達(dá)。類似于k-means,上述問題可以通過交替方案(固定一個(gè),求解另一個(gè))進(jìn)行求解:
θt←argminθL(θ,ηt?1)ηt←argminηL(θt,η)\theta^t \leftarrow argmin_{\theta} \mathcal{L}(\theta, \eta^{t-1}) \\ \eta^t \leftarrow argmin_{\eta} \mathcal{L} (\theta^t, \eta) θt←argminθ?L(θ,ηt?1)ηt←argminη?L(θt,η)
對(duì)于θ\thetaθ的求解,可以采用SGD進(jìn)行子問題求解,此時(shí)stop-gradient是一個(gè)很自然的結(jié)果,因?yàn)樘荻认炔灰聪騻鞑サ?span id="ze8trgl8bvbq" class="katex--inline">ηt?1\eta^{t-1}ηt?1,在該子問題中,它是一個(gè)常數(shù);對(duì)于η\etaη的七屆,上述問題將轉(zhuǎn)換為:
ηxt←Eτ[Fθt(τ(x))]\eta^t_x \leftarrow E_{\tau} [\mathcal{F}_{\theta^t}(\tau(x))] ηxt?←Eτ?[Fθt?(τ(x))]
結(jié)合前述介紹,SimSiam可以視作上述求解方案的一次性交替近似。
此外需要注意:(1)上述分析并不包含預(yù)測器h;(2) 上述分析并不包含對(duì)稱損失,對(duì)稱損失并非該方法的必選項(xiàng),但有助于提升精度。
Proof of concept
作者假設(shè):SimSiam是一種類似交錯(cuò)優(yōu)化的方案,其SGD更新間隔為1。基于該假設(shè),所提方案在多步SGD更新下同樣有效。為此,作者設(shè)計(jì)了一組實(shí)驗(yàn)驗(yàn)證上述假設(shè),結(jié)果見下表。
在這里,1?step1-step1?step等價(jià)與SimSiam。可以看到:multi-step variants work well。更多步的SGD更新甚至可以取得比SimSiam更優(yōu)的結(jié)果。這就意味著:交錯(cuò)優(yōu)化是一種可行的方案,而SimSiam是其特例。
Comparison
前述內(nèi)容已經(jīng)說明了所提方法的有效性,接下來將從ImageNet以及遷移學(xué)習(xí)的角度對(duì)比一下所提方法與其他SOTA方法。
上圖給出了所提方法與其他SOTA無監(jiān)督學(xué)習(xí)方法在ImageNet的性能,可以看到:SimSiam可以取得具有競爭力的結(jié)果。在100epoch訓(xùn)練下,所提方法具有最高的精度;但更長的訓(xùn)練所得收益反而變小。
上表給出了所提方法與其他SOTA方法在遷移學(xué)習(xí)方面的性能對(duì)比。從中可以看到:SimSiam表達(dá)可以很好的遷移到ImageNet以外的任務(wù)上,遷移模型的性能極具競爭力。
最后,作者對(duì)比了所提方法與其他SOTA方法的區(qū)別&聯(lián)系所在,見上圖。
-
Relation to SimCLR:SimCLR依賴于負(fù)采樣以避免“崩潰”,SimSiam可以是作為“SimCLR without negative”。
-
Relation to SwAV:SimSiam可以視作“SwAV without online clustering”.
-
Relation to BYOL: SimSiam可以視作“BYOL without the momentum encoder”.
全文到此結(jié)束,對(duì)該文感興趣的同學(xué)建議去查看原文的實(shí)驗(yàn)結(jié)果與實(shí)驗(yàn)分析。
Conclusion
該文采通過非常簡單的設(shè)計(jì)探索了孿生網(wǎng)絡(luò),所提方法方法的有效性意味著:孿生形狀是這些表達(dá)學(xué)習(xí)方法(SimCLR, MoCo,SwAR等)成功的關(guān)鍵原因所在。孿生網(wǎng)絡(luò)天然具有建模不變性的特征,而這也是表達(dá)學(xué)習(xí)的核心所在。
相關(guān)文章
總結(jié)
以上是生活随笔為你收集整理的何恺明团队最新力作SimSiam:消除表征学习“崩溃解”,探寻对比表达学习成功之根源的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Graph Normalization
- 下一篇: Transformer再下一城!low-