ICLR 2022 | 从因果不变性视角探讨图神经网络的分布外泛化鲁棒性
?作者 |?吳齊天
單位 |?上海交通大學
研究方向 |?圖神經(jīng)網(wǎng)絡
論文題目:
Handling Distribution Shifts on Graphs: An Invariance Perspective
作者信息:
吳齊天(上海交通大學),張恒瑞(伊利諾伊大學),嚴駿馳(上海交通大學),David Wipf(Amazon Web Service)
論文鏈接:
https://arxiv.org/pdf/2202.02466.pdf
如何提高在新數(shù)據(jù)上的泛化性能是機器學習的一個核心問題。然而,近年來很多研究表明神經(jīng)網(wǎng)絡對數(shù)據(jù)的分布異常敏感,當測試數(shù)據(jù)的分布與訓練數(shù)據(jù)呈現(xiàn)明顯不同時,模型的泛化性能將受到很大的影響。這也為深度學習的實際應用與落地帶來了隱憂與困難,特別是針對一些高風險的領域,如自動駕駛、金融投資、醫(yī)療診斷、刑事司法等。
目前大部分關(guān)于分布外泛化問題的研究集中在歐式數(shù)據(jù)(如圖像、文本等),而對于圖結(jié)構(gòu)數(shù)據(jù)的相關(guān)研究還較少。與普通歐式數(shù)據(jù)不同的是,圖結(jié)構(gòu)數(shù)據(jù)上的分布偏移問題需要解決不同的技術(shù)挑戰(zhàn)。首先,由于節(jié)點的互連特性,數(shù)據(jù)樣本通常是非獨立同分布的,這就為數(shù)據(jù)生成分布的建模帶來了困難。其次,除了節(jié)點特征外,圖的結(jié)構(gòu)也蘊含了重要的信息,會影響到表示學習和預測任務。
本文介紹被International Conference on Learning Representations (ICLR‘22) 會議接收的一項新工作,關(guān)注圖結(jié)構(gòu)數(shù)據(jù)的分布偏移問題,主要貢獻如下:
對圖上的分布外泛化問題給出了形式化定義,并提供了基于因果不變性假設的分析視角。
從理論上證明了傳統(tǒng)學習方法無法實現(xiàn)有效的分布外泛化,并提出了一種新的目標函數(shù)(探索-外推風險最小化),用于實現(xiàn)從有限的觀測數(shù)據(jù)向測試分布的外推。
通過理論分析表明新的目標函數(shù)可以有效解決分布外泛化問題,并且訓練過程可以有效降低測試數(shù)據(jù)上的泛化誤差。
為了驗證提出的方法,考慮了三個不同的場景(處理人造混淆噪聲、跨圖遷移、動態(tài)圖時序外推),并在多個不同的GNN主干模型上展示了方法的有效性和穩(wěn)健性。
背景:圖上的分布偏移與分布外泛化
宏觀層面的定義 我們假設輸入數(shù)據(jù)是一個圖,它包含了兩部分信息: 輸入鄰接矩陣和節(jié)點特征。這里表示節(jié)點的集合。此外,每個節(jié)點對應一個標簽,所有節(jié)點的標簽組成了一個向量。我們定義表示輸入圖的隨機變量(是隨機變量的一個具體實現(xiàn)),而是標簽向量的隨機變量(同理是一個具體實現(xiàn))。此外,我們引入一個環(huán)境變量,它表示與數(shù)據(jù)生成相關(guān)聯(lián)的某種上下文信息(未被觀測到)。于是,圖數(shù)據(jù)的生成過程可以由聯(lián)合分布的展開進行描述
然而上述的定義方式不方便對圖上的分布外泛化問題進行分析和求解,特別是考慮到圖上的節(jié)點級任務(此時輸入數(shù)據(jù)通常只有一張圖或極少量的圖),因此我們考慮一種微觀層面的定義。
微觀層面的定義 將輸入的圖以節(jié)點為單位(通常每個節(jié)點就是一個訓練/測試樣本)分解為一系列子圖。具體的,假設表示節(jié)點的隨機變量,定義節(jié)點的階鄰居內(nèi)的節(jié)點集合為(這里是任意的正整數(shù))。中的節(jié)點形成了一個子圖,它包含了一個(局部)節(jié)點特征矩陣和一個(局部)鄰接矩陣。同樣定義為子圖的隨機變量而是其具體實現(xiàn)。定義是節(jié)點標簽的隨機變量,對應具體的實現(xiàn)。由此,我們將輸入圖分解為一系列子圖的集合,這里我們可以將視為模型(例如圖神經(jīng)網(wǎng)絡)的輸入,是輸出。注意到,這里可以視為是節(jié)點的馬爾可夫毯(Markov Blanket),可以被分解為個獨立相同的分布的乘積,即。
基于上述定義,我們可以把觀測數(shù)據(jù)從數(shù)據(jù)生成分布的采樣生成過程考慮成兩步:1)首先采樣一個完整的輸入圖,而它可以被視作一系列子圖的集合; 2)接著對圖上的每一個單一節(jié)點,采樣其標簽。下面我們給出圖上的分布外泛化問題的數(shù)學定義。
分布外泛化問題的形式化定義:給定訓練數(shù)據(jù)(其數(shù)據(jù)分布為),模型的目標是最終泛化到新的測試數(shù)據(jù)(其數(shù)據(jù)分布為)。我們定義表示環(huán)境變量的取值集合,是模型的預測函數(shù)即,是損失函數(shù)。于是,我們的優(yōu)化目標可以寫為
402 Payment Required
當然,這里的第一步采樣可以被省略,因為大部分圖上的問題(如節(jié)點分類)通常假設只有一個輸入圖(包含成千上萬的節(jié)點)。
方法:基于因果不變性的風險外推
直接解決上述的問題是非常困難的,因為模型在沒有結(jié)構(gòu)性假設和對學習任務的先驗知識的情況下往往是不可能實現(xiàn)分布外泛化的。為此,本文從數(shù)據(jù)生成的角度,通過利用數(shù)據(jù)背后的因果不變性[1,2,3],來引導模型學習到可以實現(xiàn)泛化的映射關(guān)系。
在進入技術(shù)細節(jié)之前,我們首先考慮一個具體的例子作為熱身。我們考慮一個引用網(wǎng)絡,每個節(jié)點表示一篇論文,每條連邊表示論文之間的引用關(guān)系。每個節(jié)點有兩個特征——論文發(fā)表的會議與論文的影響力,標簽是論文的主題,環(huán)境是論文發(fā)表的時間。我們可以將上述變量的因果關(guān)系表示為下圖:
具體的,對圖中的三個因果依賴關(guān)系可以作如下理解。1):論文發(fā)表的會議會決定論文研究的主題;2):論文的影響力往往與論文的主題有關(guān);3):論文的影響力還與論文發(fā)表的時間有關(guān)(研究方向的流行度會隨時間變化)。
在這個例子中,會同時與和有關(guān)。也就是說,當環(huán)境發(fā)生變化時(對應于數(shù)據(jù)采樣的分布發(fā)生了變化),與之間的關(guān)系也會發(fā)生變化。因此,如果模型在訓練集上學習到了這部分關(guān)聯(lián)性,當遷移到測試集后就不能獲得令人滿意的結(jié)果(因為環(huán)境的改變導致了與關(guān)系的改變)。相反的,如果模型在訓練集中學習到了與的關(guān)系,就能夠成功遷移到測試集(因為就算環(huán)境發(fā)生了改變,與之間的關(guān)系是穩(wěn)定不變的)。
傳統(tǒng)學習方法的局限性 我們接著把上述例子進行一般化推廣,假設數(shù)據(jù)間的關(guān)系為線性且圖神經(jīng)網(wǎng)絡只考慮一階鄰居。也就是說,(以及)只包含圖上的一階鄰居。此外,假設每個節(jié)點有兩個特征,,數(shù)據(jù)生成的過程假設如下:
402 Payment Required
這里和表示由獨立的標準正態(tài)分布產(chǎn)生的噪聲,是一個均值為零,方差大于零且與環(huán)境有關(guān)的隨機變量。上面提到的引用網(wǎng)絡的例子正是這個假設場景的具體體現(xiàn)。此外,我們考慮用于預測的圖神經(jīng)網(wǎng)絡模型
402 Payment Required
,其中都是需要學習的模型參數(shù)。可以看到,理想的最優(yōu)模型參數(shù)是,此時GNN只利用了,即不受環(huán)境影響的特征。但是,我們可以證明,當直接使用傳統(tǒng)的經(jīng)驗風險損失(Empirical Risk Minimization,即直接優(yōu)化訓練數(shù)據(jù)的損失)作為訓練目標時,模型無法實現(xiàn)分布外泛化(一定會學習到與的關(guān)系)。
命題1.假設定義在與環(huán)境有關(guān)的數(shù)據(jù)上的誤差損失為
402 Payment Required
, 則當采用優(yōu)化目標,模型訓練得到的參數(shù)為,這里表示在不同環(huán)境下的方差。這一結(jié)論表明,當我們直接使用傳統(tǒng)的學習目標進行訓練時,模型必然會學習到數(shù)據(jù)中對環(huán)境敏感的特征,從而無法實現(xiàn)理想的分布外泛化。與此同時,我們可以得到另一個結(jié)論。
命題2.當采用優(yōu)化目標時,模型對應的唯一最優(yōu)解為。
這表明,我們可以同時最小化在不同環(huán)境的數(shù)據(jù)上的平均損失以及損失的方差,這就可以幫助模型學習到與環(huán)境無關(guān)的從輸入特征到輸出的關(guān)系,實現(xiàn)分布外泛化。
基于以上的探索和洞察,對于訓練數(shù)據(jù),我們可以考慮一個預測模型,將學習目標定義為在不同環(huán)境上對應風險損失的均值和方差:
這里定義,是一個權(quán)重超參數(shù)。然而,上式則要求訓練數(shù)據(jù)中包含來自多個環(huán)境的觀測數(shù)據(jù),并且數(shù)據(jù)樣本與環(huán)境的對應關(guān)系也是已知的。對于圖結(jié)構(gòu)數(shù)據(jù),尤其是節(jié)點級任務,這兩個要求都是不滿足的。通常情況下,訓練數(shù)據(jù)只包含了一整張大圖,即可以認為數(shù)據(jù)樣本都來自同一環(huán)境。為了解決這一困難,我們引入個額外的數(shù)據(jù)生成器(),基于輸入圖生成份不同的圖數(shù)據(jù) 來探索環(huán)境,模擬來自不同環(huán)境的觀測數(shù)據(jù)。基于此,我們考慮如下的雙層優(yōu)化學習目標:
這里我們定義每個圖數(shù)據(jù)所對應的損失函數(shù)
402 Payment Required
。針對數(shù)據(jù)生成器,我們將其參數(shù)化為一個圖結(jié)構(gòu)學習器(graph editor),即將每一條連邊假設為自由參數(shù),對輸入圖進行局部改變(刪除或增加連邊)。具體的,我們將每一個改變視為動作(action),最終使用基于策略梯度的REINFORCE算法進行優(yōu)化,以解決離散動作空間采樣不可導的問題。我們將本文提出的方法稱為Explore-to-Extrapolation Risk Minimization(EERM),下圖給出了訓練過程的數(shù)據(jù)流圖。
理論分析
為了證明提出方法的有效性,下面我們進行理論分析,討論學習目標產(chǎn)生的最優(yōu)解與分布外泛化目標之間的關(guān)系,以及在訓練數(shù)據(jù)上得到的模型在測試集上的泛化誤差。為了提升閱讀效率,在本節(jié)中我們使用概括性的語言取代部分數(shù)學符號和公式,具體的內(nèi)容和證明請參見論文原文。首先,我們引入兩個假設條件。
假設1.對于輸入圖數(shù)據(jù)中的任意子圖,存在一個非線性映射由它給出的特征滿足1) (不變性條件): ,以及2) (充分性條件): ,其中是一個非線性函數(shù), 是獨立的隨機噪聲。
假設2.對于滿足假設1的數(shù)據(jù),存在一個隨機變量使得。我們假設在不同的環(huán)境下可以任意變化。
直觀上講,假設1保證了輸入數(shù)據(jù)中存在一部分滿足理想條件的預測信息,這部分信息可以充分預測標簽而且這一關(guān)聯(lián)對于不同環(huán)境是不變的。假設2保證了輸入數(shù)據(jù)中還存在一部分對環(huán)境敏感的信息,隨著環(huán)境的改變,它們與標簽的關(guān)聯(lián)性可以任意的變化。
定理1.在假設1與假設2的條件下,如果GNN模型產(chǎn)生的分布滿足:1)不變性條件,即互信息?,以及2)充分性條件,即被最大化,則學習目標(2)產(chǎn)生的預測函數(shù)對應(1)定義的分布外泛化問題的最優(yōu)解。
這一結(jié)論表明,本文提出的方法可以在理論上保證取得理想的分布外泛化問題的最優(yōu)解。此外,我們還可以從信息論的視角對測試分布上的泛化誤差給出相應的分析。
定理2.只要模型輸入與輸出關(guān)于節(jié)點表示的條件互信息在訓練集和測試集上是相等的,即,則對公式(2) 在訓練數(shù)據(jù)上進行優(yōu)化時,會降低測試集上預測分布與真實數(shù)據(jù)分布的距離
402 Payment Required
的上界。這一結(jié)論表明,只要模型給出的節(jié)點表示在訓練集和測試集上具有相同的表達能力(具體量化為輸入與輸出包含在表示向量中的信息),本文的優(yōu)化目標可以降低測試分布上的泛化誤差上界。這進一步從理論上驗證了提出方法的有效性。
實驗結(jié)果
為了驗證提出的方法,我們需要設計實驗,測試模型在不同數(shù)據(jù)分布上的性能。真實的圖數(shù)據(jù)中可能包含多種不同的分布偏移,這里我們考慮三種情況:人造混淆噪聲(Artificial Transformation)、跨圖領域遷移(Cross-Domain Transfer)、動態(tài)圖時序泛化(Temporal Evolution)。下表展示了本文使用的6個數(shù)據(jù)集以及對應的分布偏移的形式。
處理人造混淆噪聲 我們首先考慮Cora和Amazon-Photo數(shù)據(jù)集,對其引入噪聲,方法如下:采用兩個隨機初始化的GCN,第一個GCN基于原始節(jié)點特征生成節(jié)點真實標簽,第二個GCN基于節(jié)點標簽和環(huán)境id生成冗余特征,于是節(jié)點的特征為原始特征和冗余特征的拼接。
對每個數(shù)據(jù)集,我們將環(huán)境id設為1-10,總共生成10張圖,第一張用于訓練,第二張驗證,其余的作為測試。如此下來,訓練集與測試集之間就被引入了分布偏移,原始特征與標簽的關(guān)系是對于環(huán)境不變的,而冗余特征與標簽的關(guān)系則是環(huán)境敏感的。
我們考慮使用GCN作為預測模型主干,下圖分別顯示了使用傳統(tǒng)方法(Empirical Risk Minimization,ERM,即直接優(yōu)化訓練數(shù)據(jù)的損失)與本文提出方法(EERM)在Cora和Amazon數(shù)據(jù)集上8個測試圖的準確率(Accuracy)對比。這里,我們重復了20次實驗(使用不同網(wǎng)絡初始化),展示了準確率的分布情況。可以看到,EERM在絕大多數(shù)情況下明顯好于ERM。
跨圖領域遷移 一種典型的分布外泛化場景是圖數(shù)據(jù)上的領域泛化(Domain Generalization)。這里我們考慮Twitch-Explicit和Facebook-100數(shù)據(jù)集,它們都是社交網(wǎng)絡,分別包含了7張和100張子圖。我們使用一部分圖作為訓練集,另一部分作為測試。由于每一張子圖都是來自不同地區(qū)的社交網(wǎng)絡,而且大小、密度、標簽分布都不盡相同,因此訓練數(shù)據(jù)與測試數(shù)據(jù)就天然存在分布偏移。
對于Twitch數(shù)據(jù)集,我們使用子圖DE作為訓練集,ENGB作為驗證集,其余作為測試集。由于是二分類問題且類別標簽不均衡,所以我們使用ROC-AUC作為評測指標。下圖顯示了分別使用GCN、GAT、GCNII作為網(wǎng)絡主干,ERM與EERM在5個測試圖上的性能對比。可以看到,EERM在幾乎所有情況下都超越了ERM。
對于Facebook數(shù)據(jù)集,我們考慮使用多個圖進行訓練。具體的,我們考慮三種訓練子圖的組合。下表顯示了使用不同訓練子圖的組合,在三個測試圖(Penn,Brown,Texas)上的準確率對比。同樣,EERM在絕大部分情況下超越了ERM,取得了最高7.2%的提升。
動態(tài)圖時序外推 另一種典型的分布偏移來源于時序動態(tài)圖,訓練數(shù)據(jù)往往是歷史某個階段收集的片段,測試數(shù)據(jù)則來源于未來。隨著時間的推移,圖數(shù)據(jù)可能發(fā)生變化。這里我們進一步考慮兩種不同的情況。第一種情況對應動態(tài)的時序snapshot,我們考慮Elliptic數(shù)據(jù)集,它一共包含49個graph snapshot,每一個記錄了在一段時間內(nèi)的金融交易,任務是識別網(wǎng)絡中的非法節(jié)點。我們把snapshot按時間順序排列,使用前5個作為訓練,第6-10個作驗證,其余的作為測試集(把每相鄰的4個合并為一組)。
我們使用F1分數(shù)作為評測指標,下圖顯示了使用GraphSAGE和GPRGNN作為主干模型的效果對比。可以看到,EERM顯著好于ERM,取得了平均9.6%/10.0%的提升。
接著我們考慮第二種情況,隨著時間的推移,圖中的節(jié)點和連邊會發(fā)生變化。這里我們考慮OGBN-Arxiv數(shù)據(jù)集,其中每個節(jié)點是論文。我們按論文的發(fā)表時間將節(jié)點分為訓練集和測試集。為了引入分布偏移,我們擴大訓練節(jié)點和測試節(jié)點的時間間隔:使用2011前發(fā)表的論文作為訓練集,2011-2014年發(fā)表的論文作為驗證集,2014年之后的為測試集。下表展示了時間在2014-2016/2016-2018/2018-2020年的測試節(jié)點上的測試準確率。可以看到,隨著時間的推移(分布偏移進一步擴大),模型的性能都呈現(xiàn)下降趨勢,但ERM的下降趨勢更為明顯。這也說明,EERM能夠有效提升模型對分布偏移的魯棒性。
討論與總結(jié)
近期也有不少工作關(guān)注神經(jīng)網(wǎng)絡在圖結(jié)構(gòu)數(shù)據(jù)上的泛化性和魯棒性,例如針對未知的特征/結(jié)構(gòu)[6]和不同大小的圖[7,8]。然而,他們主要專注于整圖級別(graph-level)任務,而不能很好的解決節(jié)點級(node-level)任務。整圖級別任務與節(jié)點級任務所關(guān)注的重點與技術(shù)難點是不同的。對于整圖級別任務,每個輸入圖(如分子圖)是一個預測對象,此時可以將輸入圖視為而圖的標簽視為。基于此,從同一環(huán)境里采樣的觀測數(shù)據(jù)可以視為獨立同分布的數(shù)據(jù)樣本,此時就可以直接使用一般的針對歐式數(shù)據(jù)的相關(guān)方法來處理分布外泛化的問題。
相比之下,對于節(jié)點級任務,每個節(jié)點是需要預測的對象,他們被一張大圖(如社交網(wǎng)絡)所連接,節(jié)點之間的相互依賴使得樣本是非獨立同分布產(chǎn)生的,因此不能直接使用現(xiàn)有的方法來解決節(jié)點級任務的分布外泛化。本文作為此方向的第一個探索性工作,提出了一個新的視角來分析和解決節(jié)點級任務的分布外泛化問題。此外,本文的方法和理論也適用于整圖級別的任務,甚至擴展到其他數(shù)據(jù)(如圖像、文本),特別是處理從單一觀測環(huán)境泛化到未知的測試環(huán)境。
除此之外,最近的一些研究工作探索了開放環(huán)境下的外推與泛化問題,例如對于輸入空間的新特征[9]或未知用戶節(jié)點[10]。本質(zhì)上,這些工作同樣考慮了分布外泛化問題,只不過這里的分布偏移來自于擴張的輸入空間。另外,[9,10]所提出的方法利用了圖神經(jīng)網(wǎng)絡作為一種顯式的機制,實現(xiàn)利用已有實體的表示向量來推理計算新實體的表示向量。相比之下,本文的方法則對應一種隱式的機制,通過數(shù)據(jù)生成分布背后的不變性,來引導模型實現(xiàn)泛化與外推。
針對本文有任何疑問,或希望進一步討論可以發(fā)送郵件至echo740@sjtu.edu.cn,或加微信myronwqt228。
參考文獻
[1] Mateo Rojas-Carulla, Bernhard Sch?lkopf, Richard E. Turner, and Jonas Peters. Invariant models for causal transfer learning. Journal of Machine Learning Research, 2018.?
[2] Martín Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. CoRR, abs/1907.02893, 2019.?
[3] Peter Bühlmann. Invariance, causality and robustness. CoRR, abs/1812.08233, 2018.?
[4] Keyulu Xu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka. How neural networks extrapolate: From feedforward to graph neural networks. In ICLR, 2021.?
[5] Gilad Yehudai, Ethan Fetaya, Eli A. Meirom, Gal Chechik, and Haggai Maron. From local structures to size generalization in graph neural networks. In ICML, 2021.?
[6] Beatrice Bevilacqua, Yangze Zhou, and Bruno Ribeiro. Size-invariant graph representations for graph classification extrapolations. In ICML, 2021.?
[7] Qitian Wu, Chenxiao Yang, and Junchi Yan. Towards open-world feature extrapolation: An inductive graph learning approach. In NeurIPS, 2021.?
[8] Qitian Wu, Hengrui Zhang, Xiaofeng Gao, Junchi Yan, and Hongyuan Zha. Towards open-world recommendation: An inductive model-based collaborative filtering approach. In ICML, 2021.
特別鳴謝
感謝 TCCI 天橋腦科學研究院對于 PaperWeekly 的支持。TCCI 關(guān)注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學術(shù)熱點剖析、科研心得或競賽經(jīng)驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺已發(fā)表或待發(fā)表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競爭力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯(lián)系方式(微信),以便我們在稿件選用的第一時間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的ICLR 2022 | 从因果不变性视角探讨图神经网络的分布外泛化鲁棒性的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 彩虹股份属于科技股吗 主营业务有哪些
- 下一篇: win10新更新网络连接失败怎么办啊 解