楊廣乾,李金龍
(中國科學技術大學 計算機科學與技術學院,安徽 合肥230026)
圖結構化數(shù)據(jù)廣泛存在于現(xiàn)實世界中,圖神經(jīng)網(wǎng)絡(GNN)已被證明可以有效地學習圖結構化數(shù)據(jù)背后的知識[1-2]。圖神經(jīng)網(wǎng)絡基于傳播機制,通過聚合圖中節(jié)點的鄰居信息來學習潛在表示,可以用于下游任務,例如節(jié)點分類[2-3]、圖分類[4-5]、連接預測[6-7]等。
受自然語言處理和計算機視覺中注意力機制的啟發(fā),研究人員也開始探索圖結構學習中的注意力機制。最廣泛使用的注意力機制是圖注意力網(wǎng)絡(Graph Attention Network)[8],它已被證明具有出色的性能。圖注意力在消息傳遞過程中計算每對鄰居的注意力分數(shù),以衡量節(jié)點的重要性,使得圖中的歸納學習成為可能?;谶@項工作,后續(xù)工作[9-11]又進行了許多對圖注意力的研究。
在經(jīng)典的圖注意力機制中,節(jié)點之間的注意力只在直接連接的(一階)鄰居節(jié)點之間計算,然后通過堆疊層[8,11]以隱式地獲得高階注意力。然而,這種范式存在注意依賴問題。具體來說,后面的層的注意力計算依賴于前面的層,導致近程鄰居(例如直接相連的鄰居)的注意力分數(shù)會對遠程鄰居產(chǎn)生影響。圖1 中展示了層間注意力的皮爾遜相關系數(shù)隨訓練輪數(shù)的變化??梢钥吹?,相關系數(shù)在初始下降后一直呈上升趨勢,并收斂到較高的值,這表明層間的注意力得分高度相關。此問題表明,一階鄰居的注意力分數(shù)往往與高階鄰居的注意力分數(shù)呈正相關,使得注意力計算更難準確地模擬遠程依賴關系。
圖1 訓練損失與層間注意力相關系數(shù)
為了解決此問題,受自然語言處理模型的啟發(fā),本文提出了一種新的注意力計算機制。通過直接顯式地計算高階鄰居之間的注意力系數(shù)來建模它們的依賴關系。該方法有兩個優(yōu)點:(1)不同階鄰居之間的注意力可以聯(lián)合建模;(2)圖中潛在的遠程依賴關系更容易被捕獲。然而,使用高階鄰接矩陣的注意力計算會導致計算復雜度呈指數(shù)級增加。為了提高計算效率,該方法利用啟發(fā)式剪枝算法來探索更有價值的節(jié)點。
在計算多階鄰居的注意力矩陣之后,另一個主要問題是傳播鄰居信息。在不同的傳播方法中,基于多尺度的方法旨在通過K 階鄰接矩陣[12-15]直接聚合信息。但是,固定的聚合過程是不可學習的,無法適應不同圖的特性,因此往往會導致次優(yōu)的結果。
為了使聚合過程更加靈活,提出了一種自適應的方法來直接聚合來自不同跳鄰居的信息。受動態(tài)路由方法[16]的啟發(fā),該方法采用鄰居路由機制來實現(xiàn)。具體來說,鄰居路由利用與不同跳鄰居表示的耦合度來改進聚合過程,以突出了不同階鄰居表示的重要性,而不同階鄰居表示的傳播會更新節(jié)點的最終表示。通過迭代地執(zhí)行這樣的傳播操作,最終可以獲得感知不同階鄰居重要性的最終表示。
該模型命名為直接多尺度圖注意力網(wǎng)絡(DMGAT)。該方法包括兩個過程:(1)通過K 階鄰接矩陣直接計算注意力分數(shù),以直接建模當前節(jié)點和鄰居之間的依賴關系,其中使用剪枝算法來抑制計算量的指數(shù)增長;(2)采用自適應聚合直接傳播信息,其中傳播過程由鄰居路由算法實現(xiàn)。節(jié)點表示最終將用于線性層的分類。
本文的主要貢獻如下:
(1)提出了一種新的注意力計算機制,該機制直接計算多跳的注意力,并利用注意力修剪算法來避免指數(shù)的計算復雜度。
(2)提出了一種自適應多尺度聚合機制。該機制不使用固定的聚合公式,而是通過鄰居路由方法基于圖的特性來學習更靈活的表示。
模型評估的實驗結果表明,該方法在節(jié)點多分類任務上具有出色的性能。
本節(jié)簡要介紹一些以前的相關工作。使用方法主要與圖注意力和多尺度網(wǎng)絡有關。首先利用圖注意力機制通過冪鄰接矩陣直接計算注意力,然后使用從表示中學習的靈活聚合機制來傳播高階信息。
注意力機制在很多深度學習領域得到了廣泛的應用,比如自然語言處理[17-18]、計算機視覺[19-20]等,后來注意力機制又出現(xiàn)了很多變種。Graph Attention Network[8]首先提出了圖中的注意力機制,在聚合信息的同時計算不同鄰居的注意力分數(shù)。然而,后來的研究表明,基于注意力的機制在注意力計算[10]中存在過擬合問題。在此基礎上,后來的研究者們提出了更多的改進[21-24],不同的方法一般使用不同的注意力計算方法。
與這些工作不同,本工作直接通過K 階鄰接矩陣計算高階鄰居之間的注意力。通過直接的注意力計算,模型可以更好地模擬遠程鄰居之間的依賴關系,以減輕過擬合。此外,進一步設計了一種高階鄰居剪枝算法,以降低注意力計算的復雜度。
傳統(tǒng)的圖神經(jīng)網(wǎng)絡聚合來自一跳鄰居的信息,并通過堆疊層獲得多跳信息。后來的一些研究人員試圖將多跳信息卷積在一個單層中,這種方法取得了出色的性能。例如混合鄰接矩陣的多個冪進行卷積[11],以規(guī)避GCN[2]譜圖近似方法造成的建模能力降低[15];在塊 Krylov 子空間形式中推廣譜圖卷積和深度GCN,以利用多尺度信息,賦予網(wǎng)絡更強的表達能力;LanczosNet[25]利用圖拉普拉斯算子的低秩近似,快速地近似計算矩陣冪;MixHop[26]為一層中的每個鄰接冪學習多個權重矩陣,并為每一跳調(diào)整潛在維度,從而在一層學習GCN 無法表示的高階微分操作;N-GCN[13]進一步提出使用權重來縮放GCN不同階鄰接矩陣的表征;mLink[14]提出了一種節(jié)點聚合方法,該方法將輸入的封閉子圖迭代地變換為不同的尺度;MSGE[27]基于隨機游走確定多個不同尺度的子圖,然后采用圖嵌入方法來學習超節(jié)點嵌入;MSGA[28]設計了一個多尺度自我表達模塊,用于從每一層獲得更具辨別力的表示。
本工作設計了一種鄰居路由機制[16],該機制通過節(jié)點的不同跳的表示來計算其與最終表示的耦合度,以便節(jié)點可以根據(jù)其特性獲得不同的路由系數(shù)。通過迭代傳播,最終表示可以感知各階鄰居的重要性,從而優(yōu)化了聚合過程。
本節(jié)首先介紹文章中使用的符號。假設觀察到一個圖G=(V,E),其中V 是節(jié)點集,E 是邊集,N=|V|是節(jié)點集中的節(jié)點數(shù)目。對于任意一個節(jié)點vi∈V,其都有一個對應的特征向量,記作xi∈RF,從而構成一個特征矩陣X∈RN×F,其中F 表示特征的維度。進一步,圖G 的鄰接矩陣可以表示為A∈RN×N,其中如果兩個節(jié)點vi和vj間存在邊,則Aij=1,否則Aij=-∞。作為鄰居矩陣的拓展,可以用計算A 的k次冪的方法得到k 階鄰接矩陣的概念,記作Ak,它表示兩個節(jié)點是k 階相連的,并且對應的值表示節(jié)點之間的路徑數(shù)。此外,使用一個恒等矩陣作為自連接矩陣,即A0=I,以保留節(jié)點自身的信息。
該模型旨在學習直接融合了多階信息的節(jié)點表征。圖2 闡述了模型的流程。它采用節(jié)點的特征矩陣X 和所有的1 階,2 階,…,K 階鄰接矩陣A1,A2,…,AK∈RN×N作為輸入,再使用權重共享的注意力處理來自不同階的鄰居,然后使用自適應的多跳聚合來傳播來自不同階的信息,最后輸出包含了所有k 階鄰居信息的最終表征,其中k=1,2,…,K,K是一個表示鄰居最大階數(shù)的超參數(shù)。
直接計算目標節(jié)點vi與它的鄰居節(jié)點vj間的注意力系數(shù),包括直接相連的鄰居和高階鄰居。為了進一步降低計算復雜度,模型采用剪枝算法來去除部分不重要的邊。
2.2.1 直接高階注意力
首先,輸入節(jié)點特征矩陣X ∈RN×F被一個線性方程轉換成潛在表示H∈RN×F′,如式(1)所示:
其中Watt∈RF′×F是一個可訓練的權重矩陣。
圖2 模型框架圖
給定一個潛在空間表示H 后,兩個k 階鄰接節(jié)點的注意力系數(shù)由注意力函數(shù)a(hi,hj)計算,如式(2)所示:
其中wu,wv∈RF′是在不同階鄰居間共享的可訓練參數(shù),||是拼接操作,LeakyReLU 是非線性激活函數(shù)。注意力系數(shù)用于度量hi和hj間的依賴關系,從而使模型基于節(jié)點對間的耦合度學到不同的關系。
接下來,基于高階連接重新調(diào)整注意力系數(shù)eijk ,如式(3)所示:
其中p 是系數(shù)dropout 概率,ξ 是一個0 到1 間的隨機數(shù)。確切地說,k 階鄰居的注意力系數(shù)以p 的概率被選擇。注意這里使用剪枝的高階鄰接矩陣進行注意力計算,該方法將在下一小節(jié)介紹剪枝算法。
為了注意力系數(shù)的可比性以及數(shù)值的穩(wěn)定性,使用softmax 函數(shù)來對注意力系數(shù)進行歸一化,如式(4)所示:
通過式(4)可以得到所有鄰居節(jié)點的歸一化注意力矩陣,該矩陣反映了哪些鄰居應該被分配更多的注意力。進一步,基于該矩陣得到用于聚合過程的轉移矩陣,如式(5)所示:
與GAT 不同,本方法中來自不同階的鄰居共享同一個線性變換的權重矩陣。這種共享的方式有多個好處:首先,相同的注意力矩陣會將所有的節(jié)點特征映射到相同的潛在空間中,從而更容易捕獲不同階鄰居間的特征相似性以減輕過擬合;其次,與非共享的方式相比,共享將注意力計算需要的參數(shù)量降低到原來的1/k,從而顯著降低了內(nèi)存的消耗。并且與DAGN[29]不同,該方法直接計算包括高階鄰居在內(nèi)的不同階鄰居的注意力系數(shù),而不是只計算直接相連的鄰居再進行擴散過程。
2.2.2 高階邊剪枝
此前使用高階鄰接矩陣來進行直接注意力計算。通過分析可以得出,計算的復雜度與邊的數(shù)量呈線性關系。然而,隨著鄰接矩陣冪階數(shù)的增長,邊的數(shù)量會隨之指數(shù)增長,并且聚合的信息會變得更加嘈雜。為了解決這個問題,提出了一個邊剪枝算法,如式(6)所示:
其中t 是一個可以基于圖性質(zhì)調(diào)整的超參數(shù)。采用這種剪枝形式是因為高連通的hub 節(jié)點往往傾向于連接不同的群體[30],因此它們更可能對于當前節(jié)點來說是噪聲。這種方式降低了高階鄰接矩陣的邊的數(shù)目,并鼓勵節(jié)點探索更可能未被傳播過的節(jié)點信息。
聚合過程可以基于圖的結構和特征調(diào)整為一個可學習的過程,從而對應不同圖的特性?;谶@種思想,設計了一個基于路由算法的自適應聚合過程。
2.3.1 鄰居路由
然后,使用非線性的tanh 函數(shù)在行維度激活父膠囊H′。進一步,計算父膠囊與子膠囊間耦合系數(shù)來更新對數(shù)先驗概率,如式(9)、式(10)所示:
事實上,鄰居路由本質(zhì)上是在度量每階鄰居表征對于下游任務的作用,其中對下游任務幫助更大的表征將會被分配更高的路由系數(shù)。該結構中涉及較少的參數(shù)量,故有利于緩解過擬合的問題。
聚合了多階鄰居信息的最終表征H′將會使用一個全連接層,從而用于分類任務,如式(11)所示:
其中W 是一個用于分類的線性變換,ReLU 是一個非線性激活函數(shù)[31],softmax 被用于行維度進行分類。
該方法使用交叉熵損失函數(shù)來優(yōu)化模型,如式(12)所示:
其中YL是有標簽的節(jié)點集合。整個模型使用梯度下降算法進行優(yōu)化。
該方法對于所有鄰接矩陣都使用稀疏表示,鄰接冪的計算通過系數(shù)矩陣乘法進行,復雜度為O(N|E|)[32]。對于高階冪的剪枝算法可以將復雜度降低到與A 相同的階數(shù),即O(|E|),因此注意力系數(shù)的計算可以近似為O(K|E|)。
該方法對幾個數(shù)據(jù)集在半監(jiān)督節(jié)點分類任務上進行了大量的實驗。首先介紹所使用數(shù)據(jù)集的一些統(tǒng)計數(shù)據(jù),并詳細說明實驗的一些設定。為了驗證提出的模型 DMGAT 的性能,本節(jié)將其與節(jié)點分類任務的表現(xiàn)最好的模型進行比較,進行了更多的參數(shù)實驗和消融實驗,以研究模型的超參數(shù)敏感性和不同部分的有效性。最后對實驗結果進行一些分析。
本節(jié)首先介紹一些實驗使用的數(shù)據(jù)集的基本信息。對于節(jié)點分類任務,使用三個基準引文網(wǎng)絡進行半監(jiān)督學習,分別是Cora、Citeseer 和Pubmed[33]。在這些數(shù)據(jù)集中,節(jié)點表示文章,邊表示引用關系。每個節(jié)點的特征是文章的詞袋特征向量,其使用布爾值來表明一些關鍵詞在對應的文檔中是否存在,且每個文檔都有一個類標簽。表1 展示了這些數(shù)據(jù)集的統(tǒng)計數(shù)據(jù),表中展示了各數(shù)據(jù)集的節(jié)點數(shù)、邊數(shù)、特征數(shù)以及類別數(shù)。
表1 數(shù)據(jù)集統(tǒng)計數(shù)據(jù)
實驗基于Pytorch Geometric[34]框架進行,并使用了和GCN 相同的數(shù)據(jù)劃分以保證比較的公平性。
對于節(jié)點分類任務,直接對其進行端到端的訓練,并使用SGD[35]優(yōu)化算法訓練模型以最小化其交叉熵損失。在這三個數(shù)據(jù)集上,學習率統(tǒng)一被設置成0.01,L2正則化權重分別被設置成0.001、0.003和0.003,隱藏特征的維度F′被設置成8,注意力頭數(shù)M 被設置成8。在驗證集的交叉熵損失上使用提前停止策略以終止訓練。為了減輕過擬合,在輸入特征和隱藏層中使用了dropout[36]方法。對于所有的數(shù)據(jù)集,訓練模型50 次并使用平均準確率作為最終結果。
本節(jié)中報告實驗的結果,并將其與一些性能最好的模型進行比較以驗證模型的表現(xiàn)。
3.3.1 基準模型介紹
為了分析模型在節(jié)點分類任務上的性能,本節(jié)將它與一些模型在半監(jiān)督分類任務上進行比較。為了保證比較的公平性,在進行實驗時選擇了一些對數(shù)據(jù)集有相同劃分的模型的結果??紤]的代表性模型包括經(jīng)典的模型、基于注意力的模型以及基于多尺度的模型。典型的基準模型包括GCN、GAT、MixHop、GDC[37]、N-GCN[38]以及S2GC[39]。這里所有的結果都是由原文中引用得到。
3.3.2 實驗結果
表2 中總結了三個數(shù)據(jù)集上節(jié)點分類準確性結果,其中基準模型的結果是直接從原文中引用而來的。
表2 三個數(shù)據(jù)集節(jié)點分類準確性結果
如表2 所示,該方法在節(jié)點分類與性能最好的方法相比有著優(yōu)異的性能。在Cora 和Citeseer 數(shù)據(jù)集上,該模型相比其他模型取得了最好的表現(xiàn),相比第二的模型分別提高了1.3%和0.7%;而在Pubmed數(shù)據(jù)集上,它也取得了第二的準確率80.6%。結果表明該方法在Cora 和Citeseer 數(shù)據(jù)集上有著優(yōu)異的性能。為了進一步驗證模型的有效性,本節(jié)設計了消融實驗,并詳盡分析了使用不同階數(shù)鄰接矩陣時的表現(xiàn)。
本節(jié)中分別替換了直接注意力和自適應多尺度路由以研究每個組件的作用。
為了研究提出模塊的有效性,本節(jié)分別移除了直接注意力、自適應多尺度聚合以進行實驗,實驗結果如表3 所示。實驗統(tǒng)一使用K=2,M=8 進行比較。直接注意力的消融實驗的實現(xiàn)方式是:使用式(1)中變換一次然后進行多次信息傳播。
表3 實驗結果
如表3 所示,相比之下,直接注意力機制在Cora、Citeseer 和Pubmed 上分別提高了0.7%,1.5%和0.8%。特別地,在Citeseer 數(shù)據(jù)集上提升最大,這表明該數(shù)據(jù)集相比其他數(shù)據(jù)集更依賴高階鄰居的信息。而自適應的聚合過程持續(xù)地提升了每個數(shù)據(jù)集的表現(xiàn),在Cora、Citeseer 和Pubmed 上分別提高了0.8%、0.6%和0.6%,實驗結果驗證了直接注意力和自適應多尺度聚合的有效性。除此之外,在直接注意力計算的權重共享設置上同樣進行了消融實驗,實驗結果表明該方法不僅降低了參數(shù)計算量,同時在一定程度上提升了準確率,其在Cora、Citeseer 和Pubmed 上分別提升了1.1%、0.4%和1.1%,實驗結果表明網(wǎng)絡模型過大的容量同樣會導致過擬合問題。
為了研究階數(shù)K 和節(jié)點分類準確率間的關系,本節(jié)分別設置K=2,3,4,并將結果與GAT 的表現(xiàn)進行對比。結果表明,相比傳統(tǒng)的堆疊方式,該方法降低了層數(shù)堆疊導致的性能下降問題。
圖3 使用t-SNE 降維的可視化
使用t-SNE[40]進行可視化以闡述模型的有效性(如圖3 所示),圖中不同顏色的點表示不同節(jié)點的潛在表征H′。如圖所示,在GAT 中使用2 階鄰居時表現(xiàn)最好,當使用3 階和4 階鄰居時,不同標簽的節(jié)點逐漸混合在一起;而在本模型中,使用3 階和4 階鄰居時,不同標簽的節(jié)點仍能分辨得相對清楚,表明該方法在一定程度上緩解了過平滑問題[41]。
本節(jié)使用了一個小提琴圖來描述不同階數(shù)路由系數(shù)的密度。這里所有數(shù)據(jù)集均使用階數(shù)K=2,研究路由系數(shù)的總體分布,最終數(shù)據(jù)如圖4 所示。
圖4 不同數(shù)據(jù)集上的路由系數(shù)
如圖4 所示,相比節(jié)點自身的信息,1 階和2階的鄰居對于表征的貢獻更大。這一現(xiàn)象符合直覺,因為在此前實驗中可以觀察到,卷積的方法相比單獨的多層感知機會顯著提升分類效果。特別在Citeseer 數(shù)據(jù)集上,路由系數(shù)相比其他數(shù)據(jù)集最穩(wěn)定,其最大值相比最小值近高出0.006 個點。相反地,Cora 數(shù)據(jù)集上的路由系數(shù)相對不穩(wěn)定,并且一些情況下節(jié)點自身的路由系數(shù)會高出鄰居的系數(shù)。而在Pubmed 數(shù)據(jù)集中,不同階表征看起來不具有明顯的區(qū)分性。
本文提出了一個圖神經(jīng)網(wǎng)絡模型DMGAT。為解決經(jīng)典注意力計算中存在的注意力依賴問題,提出了一個全新的直接注意力機制。該機制顯式地計算高階鄰居間的注意力系數(shù),使用邊剪枝算法以抑制指數(shù)增長的計算復雜度。在此基礎上,進一步提出一個自適應聚合機制,該機制使用鄰居路由算法以直接傳播多尺度信息。大量的節(jié)點分類實驗結果表明了方法有著優(yōu)異的性能。