【学术相关】作者解读ICML接收论文:如何使用不止一个数据集训练神经网络模型?...
作者:歐明鋒,浙江大學
導讀:在實際的深度學習項目中,難免遇到多個相似數據集,這時一次僅用單個數據集訓練模型,難免造成局限。是否存在利用多個數據集訓練的可能性?本文帶來解讀。
01?介紹
迄今為止,在深度學習領域,最流行的范式或者大家最常用的范式是端到端學習范式。
我們可以把該范式簡單概括為四個步驟:準備數據,喂入網絡數據,神經網絡優化,最后評估模型。這個范式確實也在各個領域取得了巨大成功。
然而,當我們在做一些實際的工程應用時,一項任務可能有多個相似數據集,比如在寵物分類的Dogs vs Cats, Oxford-IIIT Pet數據集,交通車輛檢測的BDD100k,KITTI-object等數據集。通常的做法是一次僅選擇其中的一個進行各種模型訓練,這不僅浪費了其他的數據集,也同時給模型帶來局限。
因此,我們可能會問這樣一個問題:為什么只使用一個數據集來訓練神經網絡模型?
這是我在Graviti作為算法實習生,與leader以及導師一起完成的一項研究工作,已經被ICML2021接受了,非常感謝Datawhale給我向大家分享論文。今天的分享簡單分為?介紹(包括movivation,related work等等),方法,實驗驗證,最后的結論 四個部分。
回到正題,針對上面的問題,那肯定要利用起多個數據集的。
有些數據集可以輕松融合在一起,因為他們有重疊的標簽,就像下面這兩個traffic相關的數據集有共同的標簽類 person和bike, 但有些不能,我們認為其中一個主要的瓶頸之一是標簽差異,標簽集存在不同的語義層次或粒度。
就像這里底部寵物數據集的例子,數據集a標簽是貓狗等,數據集b標簽是一些貓狗的品種如布偶貓,薩摩耶等,因為兩個數據集的標簽粒度存在差異,導致其無法直接融合。
事實上,確實有些前人的工作涉及該方面, 我將這些工作主要分為了兩類:1.是左邊的直接融合,直接在標簽空間進行,這要求標簽的一致性,這通??梢酝ㄟ^偽標簽的方式進行;2.是右邊的間接融合,它可以抽象為通過共享的隱藏向量空間進行數據集融合,相應的算法框架涉及遷移學習、領域自適應等。
而我們的思路是從數據集的語義信息角度出發, 由于具有相似目的的數據集其標簽在領域知識是具有的語義關聯,所以我們就通過構造一個統一的知識驅動的標簽圖來在標簽空間中直接進行數據集融合。
這里舉了個具體的例子,左邊的部分是動物領域的三個相似的數據集及其標簽集,由于這些標簽集之間的語義層次和粒度不同,它們無法輕松融合。然而,在通過標簽集之間的語義關系建立標簽圖之后,這些數據集成功地連接起來,三個數據集就被組合成一個單一的數據集。
更具體地來說,左邊是傳統的未融合數據集的示例,幾個相似的數據集,但標簽集之間存在差異,每個數據集對應一個單標簽預測模型的訓練過程。右邊我們提出的方法,我們將這些數據集連接在一起,驅動模型預測 標簽圖上以目標節點為終點的整個軌跡,而不是單一的標簽預測。
我們模型的基本架構就是特征提取網絡接上序列生成網絡,即Encoder-Decoder的結構。
介紹部分就到這里,接下來是方法部分。
02 方法
首先是圖譜構建的流程,這里其實是展示了一個抽象化的流程。這里假設對兩個數據集的標簽來構建圖譜, 這兩個數據集分別假設為:
貓狗二分類數據集
貓狗的細粒度品種分類數據集
構建步驟抽象為以下四個步驟, 1.首先是添加根節點,就是黃色的動物節點;2. 所有數據集的標簽節點,就是綠色的節點;3. 以及代表屬性特征的擴展節點,即藍色的節點;4. 最后連接邊。
但實際上這個圖的構建過程是更為具體和直接的,因為這個圖其實不是我們構造的,而是通過 “竊取”來的。因為這個標簽圖本質上是從相關的領域幾十年來積累的領域知識中獲得的。
以貓的品種分類為例:
首先,我們將cat設置為根節點,接著我們從Purina這樣的領域網站上發現了三種類型的coat特性。因此,我們添加它們作為增強節點來表示貓的一方面外觀特征;其次,我們check了coat field中的對應框“Short”,發現了許多短毛品種,并將它們放置在增強節點shorthair下。通過類似的方式,就可以構建出一張很大的或者說完整的標簽圖。
同時在剛剛的這個過程中,我們很容易發現,構造過程類似于人類在執行分類時的決策方式。當我們人看到一種動物時,我們首先根據它的全局特征來判斷它的大致類別,然后仔細觀察它的局部特征來確定它細分的品種。
也就是說在我們的方法中,模型在執行推理時,標簽圖其實提供了一個“決策過程”。
此外,我們認為這種方法是象征主義和連接主義的結合。也就是說,我們將幾十年積累起來的領域知識歸納為一個深度神經網絡模型。
為了更好地捕捉下方標簽圖上同一層級節點間的關系,我們定義了競爭節點的概念。
定義u和w是競爭節點,當且僅當u和w有著共同的祖先節點,并且它們在分類法上是互斥的。
針對競爭節點,我們提出了block-softmax;因為對于一般softmax,所有類別都在相互競爭。但是,在我們的體系結構中,競爭關系僅存在于競爭節點之間。因此做了一個block的限制,從而將相對概率的計算限制到了每個競爭節點組內。右圖就是一個對比示意圖:
說完節點來到路徑,我們也定義了確定性和不確定性路徑來分別處理 類別具有確定性以及不確定特征 的情況。首先是確定性路徑,它的定義如這里所示,比較抽象,我們就直接來看一個具體的例子:
給定標簽節點v和經過該節點的路徑P(v),如果不存其他路徑P′(v)滿足條件:? u∈P(v),w∈P^′(v), u,w形成競爭節點并且u ≠w 則P(v)是確定性路徑。
右圖中的一個例子就是動物-》貓-〉短毛->英國短毛貓, 之所以說這條路徑是確定的是因為,所有的英國短毛貓都是短毛的。
首先是確定性路徑的訓練,我們采用了Teacher forcing的訓練策略, 該流程如右圖所示,對于確定性ground truth路徑P,我們將其視為一個序列,讓循環單元自回歸地預測序列上的每個節點, 然后我們就能得到如下的損失函數,(本質上就是最大化整條正確路徑的概率),從而反向傳播并優化。
然后是關于非確定性路徑。給定路徑錨定(anchoring)標簽節點,,如果存一條其他路徑滿足條件:,,,形成競爭節點并且 ,則是非確定性路徑。
右圖中有三條不確定性路徑,被標記為紅色。因為英國短發貓的毛色模式可以是純色、重點色、虎斑色中的任意一種。因此,經過這三個節點到英國短毛節點的路徑都是不確定的。
由于其路徑中的不確定節點導致teacher forcing策略無法正常使用,所以我們采用了Reinforce算法。首先我們定義了一個激勵函數,即“模型采樣的生成路徑”和“ground truth標簽節點集”之間交集的歸一化大小。進而定義出了損失函數,其實本質上就是最大化采樣生成路徑的期望獎勵,能夠通過最后一個式子估計出不確定性路徑的梯度,具體的推導請參考reinforce的論文。
然后我們最終的訓練策略的話其實就是在一個batch中依次進行確定性和非確定性路徑的訓練,具體詳細的訓練流程就不在這里說了,有興趣的可以看一下我們論文中的偽代碼。
03 實驗
實驗部分我們分別在單標簽圖像和文本分類任務上進行的。
首先,關于數據集設置,分為三組:
第一組是關于寵物分類,第二組是關于花分類, 第三組是對arxiv文章進行學科分類,arxiv學科的標簽其實是有層級的,比如第一級cs,第二級 ml,arxiv augment就只保留了其最高層級的標簽。
前兩組的標簽圖都是我們通過現有的領域知識構建的,arxiv那一組標簽其實是有層級的,比如第一級cs,第二級 ml,就直接將層級關系展開為標簽圖。
組1和組3對應于細粒度和粗粒度數據集的融合,并且數據集之間沒有標簽重疊, 組2對應于在相同粒度級別上標注的兩個數據集的融合,其中重疊標簽數量為8
出于評估目的,我們的測試都是在難度更大的細粒度數據集上進行的:
然后,是關于模型的設置的。
首先是baseline, 在圖像分類中,有三種。1.傳統的單標簽預測模型 2.基于偽標簽的融合數據集,即為粗數據集中的樣本生成細粒度偽標簽,并將這些樣本合并到細粒度數據集中。3.它是一個多標簽分類設置,采用了之前工作中的一個關鍵實驗。而在文本分類任務中,基線是傳統的單標簽預測模型。
然后是我們的模型。其中對于Encoder,圖像分類任務中使用EfficientNet-b4而文本分類任務使用Bert或LSTM作為特征提取器,對于Decoder使用GRU, 并且在圖像分類任務中融合了注意力模塊來幫助GRU單元在不同的step關注到圖像中不同位置的信息。
然后是實驗的主要結果。從表中可以看出兩點:
1.如紅色虛線框中對比數據所示,即使沒有額外數據集的幫助,簡單地將標簽擴展為標簽關系圖,再加上我們的訓練策略,表現仍然會有所提升。因為將標簽擴展為標簽關系圖,其實本質上就是一種數據增強的方式,只是與傳統的數據增強方法集中于數據本身上不同,本文增強了標簽之間的關系,或者另一種角度來看本文為每個標簽的樣本又引入了額外的標簽,即額外的監督信息。
2.如綠色虛線框中的對比數據所示,使用本文所提出的方法要優于直接融合,以及基于偽標簽融合的方法,同時也要優于傳統的單標簽預測模型,說明了我們方法在標簽空間進行數據集融合的可行性。
更重要的是,我們的方法具有增強的可解釋性。為了說明這一點,我們以波斯貓為例,波斯貓用紅色虛線橢圓標記,波斯貓的毛色模式是重點色或純色,這是不確定的。該模型通過確定性的重點色和純色的貓類樣本來學習這兩種顏色模式的特征,應用在不確定性路徑樣本的推理上,從而區分波斯貓中不同毛色模式的樣本。這就像之前說的,我們的標簽圖其實就是為我們的模型在推理時提供了決策過程的過程,從而使其更具有可解釋性。實驗部分到此結束。
04 結論
在這項工作中,我們研究了數據集連接的問題,更具體地說是在標簽系統不一致時的標簽集連接問題。我們提出了一個新的框架來解決這個問題,包括標簽空間擴充、遞歸神經網絡、序列訓練和策略梯度。經過訓練的模型在性能和可解釋性方面都顯示出良好的結果。
當然這項工作只是一個多數據集連接初步的探索, 其中還有很多問題可以研究解決,包括以下:
圖譜質量的如何衡量,
如何構建更加魯棒的方法來適應的有噪聲標簽關系圖,
融合后數據集產生的分布偏移問題該如何解決,
同時直接還有很多可擴展的方向,包括:
偽標簽方法相結合
在其他任務如目標檢測、分割上進行探索
以上的話就是對我們這項工作的整體介紹,關于該項工作的更多細節可以去arxiv上看看我們的paper。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載黃海廣老師《機器學習課程》視頻課黃海廣老師《機器學習課程》711頁完整版課件本站qq群554839127,加入微信群請掃碼:
總結
以上是生活随笔為你收集整理的【学术相关】作者解读ICML接收论文:如何使用不止一个数据集训练神经网络模型?...的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 优酷视频如何将地区设置为中国大陆
- 下一篇: 优酷视频如何进行连续播放?