度量学习和pytorch-metric-learning的使用
??度量學(xué)習(xí)是學(xué)習(xí)一種特征空間的映射,把特征映射到具有度量屬性的空間中,所謂度量屬性是指在某種度量距離(可以是歐氏距離、余弦相似性等)下類(lèi)內(nèi)距離更小,類(lèi)間距離更大。有了這種屬性之后,就可以?xún)H根據(jù)特征間的距離來(lái)判斷樣本是否屬于同一類(lèi),常用在少樣本學(xué)習(xí)任務(wù)中,解決由于樣本數(shù)量少而無(wú)法或不足以建立從特征到類(lèi)別的參數(shù)化映射的問(wèn)題。有一個(gè)開(kāi)源的度量學(xué)習(xí)庫(kù)pytorch-metric-learning,集成了當(dāng)前常用的各種度量學(xué)習(xí)方法,是一個(gè)非常好用的工具。
??度量學(xué)習(xí)作為一個(gè)大領(lǐng)域,網(wǎng)上有不少介紹的文章,pytorch-metric-learning庫(kù)的官方文檔也有比較詳細(xì)的說(shuō)明和demo,所以本文不打算再對(duì)它們做細(xì)致嚴(yán)謹(jǐn)?shù)娜腴T(mén)級(jí)介紹,而是記錄我在學(xué)習(xí)過(guò)程中的一些思考和觀點(diǎn),以及代碼的便捷使用,歡迎討論。
1 度量學(xué)習(xí)的主要原理
??如前所述,度量學(xué)習(xí)的目的是把特征約束到具有度量屬性的空間中。
??首先什么是特征?通常我們認(rèn)為,深度學(xué)習(xí)網(wǎng)絡(luò)中(如ResNet50),前面的卷積部分作用是逐層深入的提取越來(lái)越高級(jí)的特征,最后的全連接層的作用則是建立特征到具體類(lèi)別間的規(guī)則關(guān)系。那么我們度量學(xué)習(xí)要處理的特征當(dāng)然應(yīng)該選擇經(jīng)卷積層完全處理完畢,輸入給全連接網(wǎng)絡(luò)的這層特征,如果這層特征的維度太高而導(dǎo)致運(yùn)算量太大(如ResNet50的卷積層最后輸出有2048通道),也可以先用全連接層降維。我們把這層特征提取出來(lái),通常稱(chēng)為嵌入特征embbeding features。
??然后是如何約束?神經(jīng)網(wǎng)絡(luò)通過(guò)加入損失函數(shù)作為約束條件,以類(lèi)別標(biāo)簽作為監(jiān)督信息,使用監(jiān)督學(xué)習(xí)的方法訓(xùn)練網(wǎng)絡(luò)參數(shù),使網(wǎng)絡(luò)輸出的嵌入特征逐漸滿(mǎn)足約束條件。所謂約束條件就是指類(lèi)內(nèi)距離小、類(lèi)間距離大,那么損失函數(shù)就是要滿(mǎn)足這兩個(gè)目標(biāo),但具體設(shè)計(jì)起來(lái)仍有很多技術(shù)在里面,主要是怎么讓訓(xùn)練收斂更快、更穩(wěn)定,收斂的結(jié)果更好。
2 幾種損失函數(shù)
2.1 Contrastive Loss
??只考慮兩兩之間的類(lèi)別和距離。注意其中加入margin的思路(即max(0,margin-loss)的方式),我覺(jué)得這個(gè)思路很有意思,值得學(xué)習(xí)。我理解加入margin的目的是這樣的:因?yàn)槲覀儍?yōu)化類(lèi)間距離的目的是越大越好,而通常損失函數(shù)在梯度下降算法中是趨向于越來(lái)越小的,所以必須讓損失函數(shù)也變成一個(gè)減函數(shù),如果我們直接取一個(gè)負(fù)號(hào),則會(huì)讓損失函數(shù)變成一個(gè)負(fù)的特別大的數(shù),這通常會(huì)導(dǎo)致訓(xùn)練不穩(wěn)定。而且我們也不需要類(lèi)間距離非常大,只要大于一定值能夠和類(lèi)內(nèi)距離明顯區(qū)分就可以了,所以這里加入一個(gè)margin,讓目標(biāo)函數(shù)限定到(0,margin)區(qū)間內(nèi),避免訓(xùn)練不穩(wěn)定。在網(wǎng)絡(luò)設(shè)計(jì)中我們經(jīng)常也會(huì)遇到想讓一個(gè)目標(biāo)函數(shù)越來(lái)越大的情況,我覺(jué)得可以嘗試這種取負(fù)號(hào)再加margin的方案,當(dāng)然另一種思路是取倒數(shù),即1/loss,我沒(méi)對(duì)比過(guò)哪個(gè)更有效,以后有機(jī)會(huì)可以試一試。
2.2 Triplet Loss
??考慮三元組,錨樣本、正樣本和負(fù)樣本之間的互相距離。由于每次同時(shí)考慮了正樣本對(duì)和負(fù)樣本對(duì)的距離都滿(mǎn)足約束關(guān)系,訓(xùn)練效率比Contrastive loss要高。注意Triplet loss的目的是讓類(lèi)間距離比類(lèi)內(nèi)距離更大,對(duì)應(yīng)的也是一個(gè)越大越好的目標(biāo)函數(shù),這里也使用了取負(fù)號(hào)再加margin并截?cái)嘭?fù)值的思路把它變成一個(gè)區(qū)間減函數(shù)形式的損失函數(shù)。
2.3 更多損失函數(shù)
??在Triplet loss之后,又發(fā)展出來(lái)了各種各樣更多的損失函數(shù),如N-pair Loss, ranklist loss, multisimilarity loss ,cirlce loss等,總的趨勢(shì)是考慮更多的對(duì)之間的距離關(guān)系,并考慮各樣本學(xué)習(xí)的難易不同,提高難例的權(quán)重。circle loss還考慮了類(lèi)別標(biāo)簽的使用,也就是綜合了分類(lèi)損失。
2.4 在pytorch-metric-learning中使用各種損失函數(shù)
??使用pip install pytorch-metric-learning安裝該庫(kù),支持的多種損失函數(shù)可查閱pytorch-metric-learning的官方文檔。如果對(duì)各種損失函數(shù)不太了解,直接使用Circle loss就好了。
from pytorch_metric_learning import losses loss_func = losses.CircleLoss() for data, labels in train_loader:embeddings,_ = model(data)loss = loss_func(embeddings, labels)??注意默認(rèn)的度量距離使用余弦相似性,如不特殊指定的話,也應(yīng)在推理階段使用余弦相似性對(duì)生成的嵌入特征進(jìn)行處理。
3 難例挖掘方法
??我們可以看到,在當(dāng)前度量學(xué)習(xí)的損失函數(shù)設(shè)計(jì)中都是含有margin的,那么網(wǎng)絡(luò)在學(xué)習(xí)的時(shí)候?qū)τ谝桌?#xff0c;損失函數(shù)已經(jīng)能夠達(dá)到margin的程度,梯度會(huì)變成0,就不會(huì)再對(duì)網(wǎng)絡(luò)有任何訓(xùn)練的作用了,只有那些難例,梯度較大,才會(huì)對(duì)網(wǎng)絡(luò)起到較大的訓(xùn)練作用,所以為了提高訓(xùn)練效率,需要只把難例加入訓(xùn)練。pytorch-metric-learning庫(kù)提供的挖掘方法在miners文件中,有很多,具體可以查看官方文檔。我在未經(jīng)充分試驗(yàn)的情況下發(fā)現(xiàn)MultiSimilarityMiner效果不錯(cuò),如果對(duì)各種挖掘方法不是很了解,可以先用這個(gè)。加入之后的訓(xùn)練方法如下:
from pytorch_metric_learning import losses, miners loss_func = losses.CircleLoss() mining_func = miners.MultiSimilarityMiner() for data, labels in train_loader:embeddings,_ = model(data)hard_tuples = mining_func(embeddings, labels)loss = loss_func(embeddings, labels, hard_tuples)4 采樣器設(shè)計(jì)
??如果訓(xùn)練樣本的類(lèi)別較多的時(shí)候,隨機(jī)采樣的話可能在一個(gè)mini batch內(nèi)遇不到幾個(gè)正樣本對(duì),例如有100類(lèi),而batch size只有64。而損失函數(shù)設(shè)計(jì)都是考慮到正樣本對(duì)的,這會(huì)導(dǎo)致訓(xùn)練效果受很大影響,所以要調(diào)整采樣器,保證一個(gè)mini batch內(nèi)有一定比例的正樣本對(duì),比如可以指定每一類(lèi)都固定采樣m個(gè)樣本,通常m=4效果較好,這樣batch size = 64時(shí),每次能夠采樣16類(lèi),每類(lèi)4個(gè)樣本。pytorch-metric-learning庫(kù)中集成了一個(gè)采樣器MPerClassSampler,它對(duì)torch.utils.data.sampler修改而來(lái),調(diào)用方法如下:
from pytorch_metric_learning import samplers sampler = samplers.MPerClassSampler(labels, m=4,length_before_new_iter=len(train_dataset)) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,sampler=sampler, **kwargs)??還有其他一些采樣器可參考官方文檔,如果還需要給采樣器加入更多的功能,比如在MPerClassSampler的基礎(chǔ)上還想保證訓(xùn)練的時(shí)候各類(lèi)別平衡采樣以解決長(zhǎng)尾問(wèn)題,可以再魔改一下官方MPerClassSampler代碼文件,也不難。
5 度量-分類(lèi)聯(lián)合訓(xùn)練
??度量學(xué)習(xí)只考慮一對(duì)數(shù)據(jù)是否是同類(lèi)或異類(lèi),并不考慮每個(gè)數(shù)據(jù)具體是哪一類(lèi),顯然它沒(méi)有充分利用標(biāo)簽提供的信息。所以如果只用度量學(xué)習(xí)的損失函數(shù)進(jìn)行訓(xùn)練,網(wǎng)絡(luò)提取的特征不充分,如果同時(shí)加入全連接層構(gòu)成的分類(lèi)器,并使用交叉熵?fù)p失接受標(biāo)簽的分類(lèi)監(jiān)督信息,可以提高網(wǎng)絡(luò)的特征提取能力。(circle loss可以部分彌補(bǔ)這一點(diǎn),但仍沒(méi)有加入分類(lèi)聯(lián)合訓(xùn)練效果好)
from pytorch import nn from pytorch_metric_learning import losses, miners loss_func = losses.CircleLoss() mining_func = miners.MultiSimilarityMiner() criterion = nn.CrossEntropyLoss() for data, labels in train_loader:embeddings,out = model(data)hard_tuples = mining_func(embeddings, labels)loss1 = loss_func(embeddings, labels, hard_tuples)loss2 = criterion(out,labels)loss = 0.1*loss1 + loss2??代碼中的embeddings指嵌入層特征,out指網(wǎng)絡(luò)最后和類(lèi)別數(shù)同維度的輸出。
6 度量學(xué)習(xí)的思考
1,為什么度量學(xué)習(xí)能夠泛化?它有多強(qiáng)的泛化能力?
??度量學(xué)習(xí)常被用在少樣本學(xué)習(xí)領(lǐng)域,因?yàn)樗谝阎?lèi)別上訓(xùn)練完成的度量空間映射能力同樣可以在未知類(lèi)別上使用。這是因?yàn)槎攘繉W(xué)習(xí)通過(guò)同類(lèi)和異類(lèi)的對(duì)比,把數(shù)據(jù)集中共性特征放大,個(gè)性特征抑制,而分類(lèi)學(xué)習(xí)更關(guān)注對(duì)每類(lèi)的個(gè)性特征,所以相比度量學(xué)習(xí)有更好的泛化能力。而如果測(cè)試集和訓(xùn)練集不僅存在類(lèi)別差異,還存在特征分布差異(域偏差)的時(shí)候,度量學(xué)習(xí)也不能夠泛化。
2,度量學(xué)習(xí)在推理階段僅使用樣本間的距離進(jìn)行判斷,顯然浪費(fèi)了特征中蘊(yùn)含的很多信息量,并不是一種最優(yōu)方案。
??更理想的方案應(yīng)該是并不對(duì)特征壓縮轉(zhuǎn)換為單維度的標(biāo)量(如距離),仍使用特征豐富的全維度信息進(jìn)行推理。比如可以使用一個(gè)fc層或多fc層的MLP網(wǎng)絡(luò)來(lái)實(shí)現(xiàn)對(duì)特征到某個(gè)單類(lèi)的規(guī)則映射,這個(gè)fc層的參數(shù)不使用梯度下降訓(xùn)練法得到(只有少樣本也很難訓(xùn)練),而是解方程得到一個(gè)最小二乘解,這個(gè)方法也許會(huì)有更好的效果,先放到這里,有待后續(xù)研究。
總結(jié)
以上是生活随笔為你收集整理的度量学习和pytorch-metric-learning的使用的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【论文学习】RepVGG: Making
- 下一篇: 深度学习主机环境配置: Win10+Nv