關(guān)鍵詞:隨機游走;圖模型;注意力機制;圖擴散
中圖分類號:TP391 文獻標志碼:A
0 引言(Introduction)
基于圖的深度學習方法在解決許多重要的圖問題上取得了成功[1-4],其中一些工作旨在使用注意力機制提取圖上的特征信息[5],但存在以下問題。
(1)典型的圖方法對于每個節(jié)點僅使用非常有限的鄰居節(jié)點信息,而更大的鄰域可以向模型提供更多的信息。一般的方法是通過疊加多個層達到傳遞全局信息的目的,但是過多的層會導致過度平滑問題[6-7],并且隨著層數(shù)增加,訓練難度會不斷增大。
(2)使用圖數(shù)據(jù)的挑戰(zhàn)性在于要找到正確的表達圖結(jié)構(gòu)的方式,傳統(tǒng)方法無法區(qū)分鄰居的位置關(guān)系[8],從而失去了節(jié)點相關(guān)的拓撲信息。基于注意力機制[5]的方法將鄰居特征的加權(quán)和作為中心節(jié)點的輸出特征,但只考慮了特征信息,并沒有反映節(jié)點不同的結(jié)構(gòu)。ZHOU等[9]提出在注意力機制中引入代表結(jié)構(gòu)信息的可學習向量,但向量的學習增大了參數(shù)量和模型復雜度,存在過擬合的現(xiàn)象。
綜上所述,本文提出一種隨機游走的圖擴散模型(GraphDiffusion Model with Random Walks,GDR),該算法通過隨機游走的擴散方式[10]訪問鄰居節(jié)點,確保了對每個中心節(jié)點的局部鄰域進行編碼。通過設(shè)置游走的相關(guān)參數(shù),可以滿足控制一定范圍內(nèi)鄰域信息的需求。本研究認為,該游走策略產(chǎn)生的鄰居節(jié)點包含了圖的結(jié)構(gòu),使得在注意力機制計算特征相關(guān)性的同時包含了節(jié)點的拓撲信息,并且這種擴散方法沒有增加模型的參數(shù),使訓練更加簡單。
1 問題描述(Problem statement)
對于圖中節(jié)點的分類任務(wù),建立圖G=(V,ε,A),其中V是圖中的節(jié)點集,ε是邊集,反映了節(jié)點之間的連通性。A∈RN×N 表示G 的鄰接矩陣。同時,建立矩陣H ∈RN×d 作為節(jié)點的輸入特征矩陣,其中N 表示節(jié)點數(shù),d 表示輸入特征維度。給定圖G 的輸入特征矩陣H,通過學習一個特征轉(zhuǎn)換函數(shù)f 得到輸出特征矩陣H'∈RN×d',再通過分類器對輸出特征的節(jié)點進行分類。
2 模型架構(gòu)(Model architecture)
2.1 整體架構(gòu)
本文提出的GDR模型的整體架構(gòu)如圖1所示,它由多個特征轉(zhuǎn)換模塊組成,每個模塊包含一個隨機游走層(RandomWalk Layer, RWL)和一個圖注意力層(Graph AttentionLayer, GAL)。RWL以隨機游走的擴散方式獲取各中心節(jié)點的鄰居,這些鄰居節(jié)點包含結(jié)構(gòu)上的依賴關(guān)系,并可繼續(xù)擴散到更大的鄰域。GAL通過圖上的注意力機制對RWL層輸出的節(jié)點及其鄰居進行特征轉(zhuǎn)換,注意力機制本質(zhì)上只針對鄰居節(jié)點特征加權(quán)求和,引入RWL后,節(jié)點之間通過結(jié)構(gòu)進行了區(qū)分,使模型包含更完備的信息。輸入的圖數(shù)據(jù)通過多個特征轉(zhuǎn)換模塊后,生成的最終特征進入神經(jīng)網(wǎng)絡(luò)分類器中進行節(jié)點類別的預(yù)測。
2.2 隨機游走層
典型的圖模型在采集鄰居節(jié)點時,只使用了每個節(jié)點非常有限的鄰域范圍,GraphSAGE(Graph SAmple and aggreGatE)模型[8]通過隨機采樣的方式從一階或二階鄰域中獲取節(jié)點,GAT模型[5]則直接使用一階節(jié)點,這種局部的鄰域通過疊加多層而不斷擴大;對于其中一層(圖2),作為鄰居節(jié)點的B 和C,相對中心節(jié)點A 沒有做結(jié)構(gòu)上的區(qū)分,在連接方式上相對A 是沒有差別的,而不同的連接方式正是鄰居節(jié)點拓撲結(jié)構(gòu)的差異?;谧⒁饬C制的特征學習方法最初是在自然語言的背景下開發(fā)的,旨在尋找線性文本中連續(xù)單詞的上下文位置關(guān)系,即線性的結(jié)構(gòu)。注意力機制無法對位置進行區(qū)分,典型的Transformer模型[11]在原始特征中加入位置信息,是一種人為設(shè)計的特征,文獻作者沒有對其進行詳細解釋。網(wǎng)絡(luò)是非線性結(jié)構(gòu),需要更豐富的鄰域范圍內(nèi)的結(jié)構(gòu)信息。SAN(StructuralAttention Network)模型[9]則是在圖中引入了一個結(jié)構(gòu)化向量,讓模型在訓練過程中自動地進行結(jié)構(gòu)化信息的學習,但增加了模型的參數(shù)。
整個特征轉(zhuǎn)換模塊分為隨機游走和圖注意力計算兩個階段。隨機游走階段首先利用輸入的鄰接矩陣A,根據(jù)公式(1)計算轉(zhuǎn)移矩陣T,其次根據(jù)公式(2)及參數(shù)α 和k 計算得到概率矩陣P,最后根據(jù)P 中的概率進行隨機游走生成鄰居節(jié)點集合S。圖注意力階段則從上一階段輸出的節(jié)點集合S中選擇中心節(jié)點的鄰居,首先利用輸入的特征矩陣H,根據(jù)公式(4)計算注意力系數(shù),其次根據(jù)公式(5)對輸入特征進行轉(zhuǎn)換,最后通過神經(jīng)網(wǎng)絡(luò)分類器進行類別預(yù)測。
3 實驗(Experiment)
本文通過轉(zhuǎn)導學習和歸納學習,將GDR模型與其他基準模型在節(jié)點分類任務(wù)中的性能進行了比較。本節(jié)總結(jié)了實驗設(shè)置、結(jié)果,并對GDR模型的相關(guān)參數(shù)進行了簡要分析。
3.1 數(shù)據(jù)集描述
實驗使用的數(shù)據(jù)集如表1所示,在3個引文數(shù)據(jù)集上預(yù)測文檔類別以評估本文模型的轉(zhuǎn)導學習能力,包括Citeseer、Cora和Pubmed[14]。數(shù)據(jù)集包含文檔的特征向量集合及文檔之間的引用鏈接列表,以此構(gòu)造出鄰接矩陣和特征矩陣,訓練過程中使用了圖中所有節(jié)點的特征向量。歸納任務(wù)部分采用了生物醫(yī)學領(lǐng)域的蛋白質(zhì)相互作用(Protein-Protein Interaction,PPI)數(shù)據(jù)集[8],由對應(yīng)于不同蛋白質(zhì)組織的圖組成,共有24張圖,每個圖的平均節(jié)點數(shù)為2 372個,節(jié)點特征維度為50維,由位置基因集、基序基因集和免疫學特征組成,該數(shù)據(jù)集中一個節(jié)點同時擁有多個標簽,并且用于測試任務(wù)的圖在訓練期間沒有被使用。
3.2 實驗設(shè)置
對于引文數(shù)據(jù)集,采用一個特征轉(zhuǎn)換模塊的GDR模型。其中,RWL的重啟概率α 為0.1,擴散系數(shù)k 為10。GAL則主要參考了GAT模型進行設(shè)置,采用8個注意力層組成的多頭注意力結(jié)構(gòu),每個頭輸出8個維度的特征(總共64個特征),使用ELU激活函數(shù)進行非線性變換,最后一層是softmax分類。由于引文數(shù)據(jù)集較小,模型采用了λ=0.005的L2正則化方法。
PPI數(shù)據(jù)集用于評價模型在跨圖上的歸納學習能力,采用了3個特征轉(zhuǎn)換模塊,RWL的重啟概率α 為0.1,擴散系統(tǒng)k 為20,GAL的設(shè)置同樣參照了GAT 模型的做法,以方便對比模型改進的效果。前兩個GAL是4個注意力層組成的多頭注意力結(jié)構(gòu),每個頭輸出256維的特征(總共1 024維),采用ELU激活函數(shù)。最后一層是單頭的注意力層,輸出維度為類別數(shù),由于每個節(jié)點屬于多個類別,因此分類采用了logistic函數(shù)。
3.3 基準模型
對于引文數(shù)據(jù)集,選用的對比模型為GAT和GCN[7],以及GCN在使用多階Chebyshev截斷時的效果。對于PPI數(shù)據(jù)集,除了與注意力模型GAT進行對比,還比較了GraphSAGE模型中提出的4種不同的聚合方法。這些方法在小范圍鄰域內(nèi)采樣節(jié)點并通過某種聚合函數(shù)計算輸出特征,如GrapshSAGE-GCN采用了GCN的圖卷積操作作為聚合函數(shù),GraphSAGE-mean 直接取所有采樣特征的平均值,GraphSAGE-LSTM將采樣的鄰居特征隨機排序后輸入LSTM進行聚合,GraphSAGE-pooling將節(jié)點特征經(jīng)過全連接的神經(jīng)網(wǎng)絡(luò)后進行最大池化聚合。
3.4 實驗結(jié)果
表2給出了在3個引文數(shù)據(jù)集上針對測試節(jié)點的分類準確率。從表2中可以看出,在對比方法中,GDR模型在3個引文數(shù)據(jù)集上都取得了最高準確率,并且注意力機制中的參數(shù)設(shè)置完全參考GAT模型,可以認為性能的提升來自RWL的隨機游走策略,表明通過擴散得到的鄰居節(jié)點提供了更多的信息。相比于SAN模型通過引入訓練時可學習的結(jié)構(gòu)化向量的方法,游走策略沒有給模型增加訓練時的參數(shù),因此不容易過擬合。
通過調(diào)整重啟概率α 和擴散系數(shù)k,觀察游走策略的作用。圖4顯示了α 取值的變化對測試集準確率的影響。雖然針對不同數(shù)據(jù)集的最佳取值略有不同,但是起伏變化基本一致,重啟概率在0.1~0.2時表現(xiàn)最佳,表明一定的重啟概率使游走兼顧了局部與全局結(jié)構(gòu),并帶來了性能的提升。由于不同的圖數(shù)據(jù)具有不同的結(jié)構(gòu),因此在訓練時需要針對具體的數(shù)據(jù)集調(diào)整該參數(shù)取值。
圖5顯示了擴散系數(shù)k 取不同值時,對模型分類結(jié)果的影響。隨著k 值的增加,準確率呈現(xiàn)上升趨勢,證明隨機游走的擴散策略對模型性能的提升是有益的。如圖5所示,使用適當?shù)膋 值(例如取值為10)有效地近似精確結(jié)果已經(jīng)足夠,更大的取值帶來的效果提升不明顯,而且可能會因為更多鄰居節(jié)點的加入而導致GAL的過擬合。
表3總結(jié)了不同方法在PPI數(shù)據(jù)集0e1GpSYVhisAwu5eLy9rC1LhCZSC57I8NQ0rxnLOzvQ=上的歸納學習結(jié)果比較,采用GraphSAGE模型中的評價方法測試了模型在未見節(jié)點上的micro F1值。由于GDR模型采用的是有監(jiān)督的學習,所以本研究比較的是GraphSAGE模型的有監(jiān)督版本。本文方法在對比中取得了明顯優(yōu)勢。GraphSAGE模型在選取鄰居節(jié)點時,使用的采樣策略沒有區(qū)分節(jié)點之間的相對關(guān)系,證明了本文方法在獲取結(jié)構(gòu)信息上具有優(yōu)勢。GAT模型在獲取鄰域時,只使用一階鄰居節(jié)點,而GDR模型在游走時可通過設(shè)置擴散系數(shù)傳播到更大的鄰域,體現(xiàn)了擴散的效果。SAN模型與本文方法取得了同樣的得分,但需要訓練一個有參數(shù)的向量矩陣,增加了模型的復雜度,而本文方法在訓練上更加簡單。
4 結(jié)論(Conclusion)
本文提出了一種基于隨機游走策略的圖擴散模型,可用于圖節(jié)點的分類。該模型在經(jīng)典的圖注意力模型基礎(chǔ)上增加了一個隨機游走層,能有效地提取圖數(shù)據(jù)中更大鄰域范圍內(nèi)的節(jié)點信息,使注意力機制同時考慮了節(jié)點的結(jié)構(gòu)和特征。在多個引文數(shù)據(jù)集及一個蛋白質(zhì)網(wǎng)絡(luò)數(shù)據(jù)集上的實驗表明,該模型對節(jié)點的分類結(jié)果優(yōu)于現(xiàn)有的經(jīng)典模型,證明了隨機游走策略的有效性。
作者簡介:
周安眾(1986-),男,碩士,講師。研究領(lǐng)域:大數(shù)據(jù)技術(shù),人工智能。
謝丁峰(1978-),男,碩士,副教授。研究領(lǐng)域:大數(shù)據(jù)技術(shù),人工智能。