王 聰,王 杰,劉全明,梁吉業(yè)
1.山西大學 計算機與信息技術學院,太原030006
2.山西大學 計算智能與中文信息處理教育部重點實驗室,太原030006
傳統(tǒng)的監(jiān)督學習,如支持向量機(support vector machine,SVM)、神經(jīng)網(wǎng)絡(neural networks,NN)等,通常需要大量良好的標記樣本對模型進行訓練,以便獲得較好的模型泛化能力。同時,在處理高維數(shù)據(jù)(如視頻、語音、圖像、文檔)時,訓練一個好的監(jiān)督模型所需要的標記樣本數(shù)量會進一步增長。這使得傳統(tǒng)監(jiān)督學習很難應用于一些缺乏標記訓練樣本的任務中。
半監(jiān)督學習(semi-supervised learning,SSL)[1]是近十多年發(fā)展起來的一種新型機器學習方法,其思想是在標記樣本數(shù)量很少的情況下,通過在模型訓練中引入無標記樣本來避免傳統(tǒng)監(jiān)督學習在訓練樣本不足(學習不充分)時出現(xiàn)性能(或模型)退化的問題。半監(jiān)督學習的研究具有重要的實用價值,因為在許多實際應用中,無標記樣本的獲取相對容易,而標記樣本的獲取成本往往較高。因此,減少標記樣本的使用能夠大幅縮減人力、時間和資源的開銷,從而降低生產(chǎn)成本。同時在標記樣本數(shù)量減少數(shù)十或數(shù)百倍(甚至更多)的情況下,半監(jiān)督算法能夠取得與傳統(tǒng)監(jiān)督學習算法相近甚至更好的效果,提升生產(chǎn)效率。半監(jiān)督學習的研究具有重要的理論價值,它是介于傳統(tǒng)監(jiān)督學習和無監(jiān)督學習之間的一種新型機器學習方法,是對傳統(tǒng)機器學習理論的拓展和補充。
圖半監(jiān)督學習(semi-supervised learning on graphs)作為半監(jiān)督學習的一個重要分支,在理論和實踐上引起了極大的關注。給定一個由少量標記節(jié)點和大量未標記節(jié)點組成的圖,它的目標是為圖中的未標記節(jié)點分配標簽。生成對抗網(wǎng)絡(generative adversarial networks,GAN)[2]由于其強大的表征能力已經(jīng)被廣泛應用于半監(jiān)督學習,但它在圖半監(jiān)督學習任務上的工作較少?,F(xiàn)有的工作主要關注在低密度區(qū)域生成未標記樣本來削弱子圖之間的信息傳播,從而使決策邊界更清晰,如GraphSGAN[3]通過GAN 在子圖之間的低密度區(qū)域生成未標記樣本,減少子圖邊緣節(jié)點的影響,從而提高圖半監(jiān)督分類效果。但受限于標記樣本過少,監(jiān)督信息的不足仍在一定程度上限制了其性能。針對這個問題,本文提出了一種新的圖半監(jiān)督學習框架(semi-supervised learning on graphs using adversarial training with generated sample,SemiGATDS),它由圖嵌入模塊、兩個生成器、一個分類器和一個判別器五部分組成。其中,圖嵌入模塊將圖映射到特征空間,在特征空間中,一個生成器生成服從真實樣本分布的標記樣本,另一個生成器生成與真實樣本分布不同的未標記樣本。分類器負責為給定的樣本分配標簽,判別器用來區(qū)分樣本標簽對是否來自真實分布。通過生成器、判別器和分類器的對抗訓練,當模型達到穩(wěn)態(tài)時,生成的標記樣本擴充了標記樣本訓練集,生成的未標記樣本削弱了子圖邊緣節(jié)點的影響,迫使分類界限更加清晰,從而提高了分類效果。本文在Cora、Citeseer、Pubmed[4]三個數(shù)據(jù)集上評估了SemiGATDS 的分類性能,并討論了不同數(shù)量的標記樣本和不同生成樣本比例對算法的影響,實驗結果驗證了本文方法的有效性。
半監(jiān)督學習旨在利用大量未標記樣本來提高模型性能。半監(jiān)督學習有以下幾種范式:生成式方法[5]、基于支持向量機的半監(jiān)督學習算法[6]、基于分歧的方法[7]和圖半監(jiān)督學習[8-9]。其中,由于圖半監(jiān)督學習解釋性強、性能優(yōu)越,受到很多的關注,它的核心思想是數(shù)據(jù)集中每個樣本對應于圖中一個節(jié)點,若兩個樣本之間的相似度很高(或相關性很強),則對應的節(jié)點之間存在一條邊,邊的“強度”(strength)正比于樣本之間的相似度(或相關性)。利用圖上的鄰接關系將標簽從標記樣本向無標記樣本傳播。
關于圖半監(jiān)督學習的研究大致分為兩類,基于圖的拉普拉斯正則化框架[10]是其中一個重要的研究方向。Zhou 等人[11]通過在損失函數(shù)中使用基于圖的拉普拉斯正則化項,在圖上平滑標簽信息。文獻[12]提出了一種基于高斯隨機場和形式化圖拉普拉斯正則化框架的算法。Belkin 等人[13]提出了一種利用幾何的邊緣分布理論進行半監(jiān)督學習的正則化方法ManiReg。另一個研究方向是將半監(jiān)督學習與圖嵌入[14]相結合。文獻[15]首次將深度神經(jīng)網(wǎng)絡引入圖的拉普拉斯正則化框架中進行半監(jiān)督學習和圖嵌入。Yang 等人[16]提出了聯(lián)合圖嵌入學習和節(jié)點標簽預測模型Planetoid。DeepWalk[17]是第一個關于圖嵌入的工作,作為一種無監(jiān)督圖嵌入學習方法,如果與分類器相結合,很容易轉化為半監(jiān)督學習基線模型。圖卷積神經(jīng)網(wǎng)絡(graph convolutional network,GCN)[18]是第一個用于圖半監(jiān)督學習的圖卷積模型,它在這個問題上表現(xiàn)出了強大的能力。
GAN 作為一種功能強大的深度生成模型,最早用來表示自然圖像上的數(shù)據(jù)分布,通過生成器和判別器的互相博弈學習產(chǎn)生更好的輸出。最近在半監(jiān)督學習框架中展示了它們的能力[19]。半監(jiān)督生成對抗網(wǎng)絡(semi-supervised generative adversarial networks,SGAN)[20]最早是在計算機視覺領域提出的。SGAN用分類器取代了GAN 中的判別器。為了防止生成器過度訓練,Salimans 等人[21]首次提出特征匹配損失,將GAN 應用于關于“K+1”類的半監(jiān)督學習。Li 等人[22]認識到生成器和判別器可能無法同時達到最優(yōu),并且無法控制生成樣本的語義信息,提出了Triple-GAN。隨著標記樣本數(shù)量的減少,Triple-GAN 的性能改善更加顯著,這表明生成的樣本標簽對可以有效地用于訓練分類器。文獻[23]意識到生成器也存在同樣的問題,從理論上解釋了為什么生成與真實樣本分布不同的樣本可以提高SSL 性能。通過精心設計生成器的損失,生成器可以生成與真實樣本分布不同的樣本,迫使分類器的決策邊界位于不同類的數(shù)據(jù)流形之間,這反過來又增強了分類器的泛化能力。
基于GAN 的圖半監(jiān)督學習的研究工作較少,如GraphSGAN。這項工作的主要思想是在子圖之間的密度間隙生成未標記樣本,削弱不同類之間的信息傳播,但是用于訓練的標記樣本過少仍然是制約其性能的關鍵。針對這個問題,本文提出了Semi-GATDS,該算法同時生成服從真實樣本分布的標記樣本和與真實樣本分布不同的未標記樣本,以提高圖半監(jiān)督學習性能。
設G=(V,E)表示一個圖,其中V代表節(jié)點集,E?V×V代表邊集。假設每個節(jié)點vi與d維實值特征向量wi∈Rd和標簽yi∈{1,2,…,K}相關聯(lián)。如果節(jié)點vi的標簽yi未知,則節(jié)點vi是一個未標記節(jié)點。設標記節(jié)點集合為VL,未標記節(jié)點集合為VU=V?VL。通常,有|VL|?|VU|。由此,本文形式化地定義圖上的半監(jiān)督學習問題,給定部分標記圖G=(VL?VU,E),使用與每個節(jié)點和圖相關聯(lián)的特征w來學習函數(shù)f,預測圖中未標記節(jié)點的標簽。
本文模型框架如圖1 所示,SemiGATDS 由五部分組成,分別是圖嵌入模塊、兩個生成器、一個分類器和一個判別器?;贕AN 的模型不能直接應用于圖數(shù)據(jù),因此,遵循文獻[3]的設置,首先使用網(wǎng)絡表示學習算法(本文使用TADW(text-associated deepwalk)[24]對節(jié)點原始特征進行預處理)學習每個節(jié)點的潛在分布表示qi,然后將潛在分布表示qi與原始特征向量wi拼接,即xi=(wi,qi)。在模型中,將生成標記樣本的生成器稱為gG,它接受真實標簽y和隨機噪聲z作為輸入,并生成以y為標簽的服從真實樣本分布的標記樣本;生成未標記樣本的生成器稱為bG,它接受隨機噪聲z為輸入,生成與真實樣本分布不同的未標記樣本;分類器C,為給定的樣本分配標簽;判別器D,判斷樣本標簽對是否來自真實樣本分布。
圖1 SemiGATDS 模型示意圖Fig.1 Illustration of SemiGATDS
在模型中,考慮“K+1”類分類問題。gG首先通過真實標簽y和200 維隨機噪聲z,采樣于先驗分布Pz(z)(實驗中使用均勻分布噪聲z)生成樣本xgG~PgG(x|y,z),與條件標簽y組成標記樣本。接著bG通過隨機噪聲z生成未標記樣本xbG~PbG(x|z) 。C接受四種不同類型的樣本:標記樣本xL、未標記樣本xU、來自gG的生成樣本xgG和來自bG的生成樣本xbG,并依據(jù)條件分布PC(y|x)為它們產(chǎn)生偽標簽。對于帶標簽的數(shù)據(jù)xL和gG生成的樣本xgG,期望C為它們分配正確的標簽(為xL分配標簽yL,為xgG分配它的條件標簽y)。對于bG生成的樣本xbG和未標記樣本xU,期望C將它們分別識別為第“K+1”類(即“假”類)和前K類其中之一。D接受C和gG生成的樣本標簽對(xC,yC) 和(xgG,ygG),以及標記樣本(xL,yL)作為輸入,并將標記樣本標簽對視為真樣本,而來自gG和C的樣本標簽對均為假樣本。定義各個部分的損失如下:
將gG的損失函數(shù)定義為:
其中,PD(x,y)表示樣本標簽對(x,y),來自真實樣本分布的概率,最小化損失函數(shù),使得gG生成更接近真實樣本分布的標記樣本。
bG的損失函數(shù)定義為:
為特征匹配損失,它最小化了bG生成樣本與真實樣本中心點之間的距離,以確保生成器在類和類之間的密度間隙中生成樣本。
為pull-away term[11],它具有增加生成特征的多樣性從而增加生成熵的效果,這里可以鼓勵bG生成更多不同類別的樣本。其中N是批次大小,xi、xj是同一批次的樣本。λ0是用來平衡兩個損失的超參數(shù),實驗中將其設置為1。
C的損失函數(shù)由四部分組成:
C的總損失是:
其中,損失和損失分別表示標記樣本和gG生成樣本的交叉熵損失,損失迫使C將未標記樣本識別為前“K”類,而損失迫使C將bG生成的樣本識別為“K+1”類。λ1、λ2、λ3是用于平衡每個損失的超參數(shù),實驗中將這三個超參數(shù)均設置為0.5。
最后,判別器D的損失由三部分組成,分別為:
D的總損失是:
其中,損失迫使判別器D增大真實標記樣本對被視為真類的概率,損失和迫使判別器D減小生成樣本標簽對被視為真類的概率,β1、β2是用于平衡每個損失的超參數(shù),實驗中將這兩個超參數(shù)均設置為1。
在訓練過程中,SemiGATDS 由三組對抗訓練組成:(1)gG通過生成以標簽y為條件的標記樣本來與D進行對抗訓練;(2)C通過為未標記樣本生成置信度高的標簽與D進行對抗訓練;(3)bG通過生成未標記樣本與C進行對抗訓練。生成的未標記樣本迫使分類界限更清晰,生成的標記樣本對擴充了監(jiān)督信息,模型從這兩種生成樣本中學習。詳細的訓練過程如算法1 所示。
算法1SemiGATDS 訓練算法
假設給定圖G=(VL?VU,E),其中節(jié)點總數(shù)為s(包含標記節(jié)點和未標記節(jié)點),節(jié)點特征維度d,圖嵌入表示維度e,節(jié)點類別數(shù)為k。本文算法的時間復雜度主要由計算節(jié)點的潛在分布表示和訓練生成器、分類器、判別器四個神經(jīng)網(wǎng)絡產(chǎn)生。其中圖嵌入算法TADW 的時間復雜度為O(s2)。
本文使用的生成器、分類器、判別器均采用全連接神經(jīng)網(wǎng)絡結構。神經(jīng)網(wǎng)絡時間復雜度依據(jù)浮點運算次數(shù)計算,一次浮點運算可以定義為一次乘法和一次加法。生成器和判別器均是擁有兩個隱藏層的神經(jīng)網(wǎng)絡,分別具有(c1,c1)個神經(jīng)元,bG生成器輸入為隨機噪聲z,維度為t1,輸出為節(jié)點特征和節(jié)點圖嵌入表示拼接后的維度d+e,第一層執(zhí)行t1×c1次乘加操作,第二層執(zhí)行c1×c1次乘加操作,最后一層執(zhí)行c1×(d+e) 次操作,總共執(zhí)行t1×c1+c1×c1+c1×(d+e)次操作,假設每批次訓練m個樣本,bG生成器的總操作次數(shù)為m(t1×c1+c1×c1+c1×(d+e)),時間復雜度為O(m×c1×(d+e))。gG生成器輸入為隨機噪聲z與標簽y的拼接,標簽y經(jīng)過編碼后其維度為t2,因此gG生成器的輸入維度為t1+t2,其每批次訓練m個樣本,gG生成器的總操作次數(shù)為m((t1+t2)×c1+c1×c1+c1×(d+e)),時間復雜度為O(m×c1×(d+e))。判別器D的輸入維度即節(jié)點特征維度為d+e,輸出為真假即維度為1,其每批次訓練總操作次數(shù)為m(c1×(d+e)+c1×c1+c1),時間復雜度為O(m×c1×(d+e))。分類器C輸入維度即節(jié)點特征維度為d+e,擁有5個隱藏層的神經(jīng)網(wǎng)絡,分別具有(c1,c1,c2,c2,c2)個神經(jīng)元輸出為類別個數(shù),其維度為k。以此類推,每批次訓練總操作數(shù)為m((d+e)×c1+c1×c1+c1×c2+2c2×c2+c2×k),時間復雜度為O(m×c1×(d+e))。
綜上,SemiGATDS 算法總的時間復雜度為O(s2)+O(m×c1×(d+e))。
數(shù)據(jù)集統(tǒng)計匯總如表1 所示。在引文網(wǎng)絡數(shù)據(jù)集Citeseer、Cora 和Pubmed 中,節(jié)點是文檔,邊是引文鏈接。標記節(jié)點數(shù)表示用于訓練的標記節(jié)點的個數(shù)。每個文檔都有以詞袋模型(bag-of-words model)表示的特征,并根據(jù)主題賦予特定的標簽。
表1 數(shù)據(jù)集統(tǒng)計Table 1 Dataset statistics
為了避免過度調(diào)整網(wǎng)絡體系結構和超參數(shù),所有實驗均使用默認設置進行訓練與測試。具體地說,分類器C有5 個隱藏層,分別具有(500,500,250,250,250)個神經(jīng)元。隨機層采用零均值高斯噪聲,隱藏層輸入標準差為0.05,輸出標準差為0.5。生成器bG具有兩個500 個神經(jīng)元的隱藏層,每個隱藏層后面都有一個批歸一化層,輸出層使用Tanh 激活函數(shù)。生成器gG和bG具有相同結構,不同的是前者以噪聲z和真實標簽y的拼接作為輸入。判別器也采用和生成器相同的隱藏層結構,只是對輸入層和輸出層作了相應的調(diào)整。模型由ADAM 進行優(yōu)化,所有參數(shù)均使用Xavier初始化方法。
為了公平比較,實驗遵循文獻[16]中的設置,對于每個類,選擇20 個樣本(文檔)作為標記樣本用于訓練,同時選擇1 000 個樣本作為測試樣本。所有實驗結果取10 次隨機拆分的平均值。在這3 個數(shù)據(jù)集中,將提出的方法SemiGATDS 與4 類方法進行了比較:
(1)基于正則化的方法LP(label propagation)[11]、ICA(iterative classification algorithm)[25]和ManiReg[13];
(2)基于圖嵌入的方法DeepWalk[17]、SemiEmb[15]和Planetoid[16];
(3)基于圖卷積的方法Chebyshev[26]、GCN[18];
(4)基于GAN 的方法Triple-GAN[22]、GraphSGAN[3]。
由于原始Triple-GAN 并未用于圖,本文在圖上重新實現(xiàn)了Triple-GAN,并復現(xiàn)了GraphSGAN,在3個數(shù)據(jù)集上進行了實驗。其中Triple-GAN 的生成器生成服從真實樣本分布的標記樣本,而GraphSGAN的生成器生成與真實樣本分布不同的未標記樣本。
本文在3 個數(shù)據(jù)集上均訓練了200 個epoch。表2 顯示了SemiGATDS 與上述方法對比的實驗結果。
表2 分類準確率匯總Table 2 Summary of results of classification accuracy 單位:%
實驗結果表明,本文方法優(yōu)于所有基于正則化、圖嵌入以及圖卷積的方法,且比Cora、Citeseer 和Pubmed 數(shù)據(jù)集上的最佳結果分別提升了2.4 個百分點、0.2 個百分點和0.4 個百分點。同時由表可知,在Cora 和Citeseer 數(shù)據(jù)集上,基于GAN 的方法均優(yōu)于其他方法,也驗證了將生成對抗網(wǎng)絡用于圖半監(jiān)督學習任務的有效性。而GraphSGAN 的效果優(yōu)于Triple-GAN,說明產(chǎn)生的與真實樣本分布不同的未標記樣本對分類效果影響更大。SemiGATDS 結合兩者的優(yōu)點,同時生成的服從真實樣本分布的標記樣本和與真實樣本分布不同的未標記樣本,共同對模型產(chǎn)生了影響,獲得了比Triple-GAN 和GraphSGAN 更好的結果,從而驗證了SemiGATDS 的有效性。
為了進一步了解SemiGATDS 使用不同數(shù)量的標記樣本訓練時的表現(xiàn),本文通過改變每類選擇的標記樣本的數(shù)量n獲得不同的訓練集。表3~表5 顯示了3 個數(shù)據(jù)集上的實驗結果。由表可知,隨著有標記樣本比例的增加,用于訓練模型的數(shù)據(jù)增加,模型能夠學到的信息越多,從訓練數(shù)據(jù)中得到的模型的分類性能越好。以Cora 數(shù)據(jù)集為例,當n為10 時,Triple-GAN、GraphSGAN 和SemiGATDS 分類準確率分別為76.4%、82.9%和83.5%;當n為20 時,它們的分類準確率上漲到81.3%、84.0%和85.4%。并且當n值相同時,SemiGATDS 所獲得的結果仍然好于GraphSGAN 和Triple-GAN。同 樣的,在Citeseer 和Pubmed 數(shù)據(jù)集上也可以觀察到相同的結果,說明生成的標記樣本可以擴充圖半監(jiān)督學習中的標記樣本訓練集,生成的未標記樣本可以強制決策邊界位于正確的位置。這兩種生成樣本同時起作用,使Semi-GATDS 獲得了更好的效果。
表3 Cora 數(shù)據(jù)集上不同數(shù)量標記樣本下的分類準確率Table 3 Classification accuracy under different number of labeled samples on Cora dataset
表4 Citeseer數(shù)據(jù)集上不同數(shù)量標記樣本下的分類準確率Table 4 Classification accuracy under different number of labeled samples on Citeseer dataset
表5 Pubmed 數(shù)據(jù)集上不同數(shù)量標記樣本下的分類準確率Table 5 Classification accuracy under different number of labeled samples on Pubmed dataset
在Cora 數(shù)據(jù)集上,本文對比了Triple-GAN、GraphSGAN 和SemiGATDS 的分類準確率與epoch的關系,實驗取了前20 個epoch 的結果,如圖2 所示。
圖2 算法在Cora 數(shù)據(jù)集上分類準確率與訓練周期的關系Fig.2 Relationship between classification accuracy and training period of algorithms on Cora dataset
通過觀察發(fā)現(xiàn)了兩個不同的訓練階段:
一階段:三個模型訓練波動比較大。推測是因為在初始階段生成的樣本質量不高,對模型造成了干擾。
二階段:模型趨于穩(wěn)定,SemiGATDS 明顯超過了Triple-GAN 和GraphSGAN。從分類器的角度看,gG生成的標記樣本用于擴充圖半監(jiān)督學習中標記樣本訓練集,bG生成的未標記樣本減少了密度間隙中鄰近節(jié)點的影響。兩種生成樣本的共同作用,使得分類器得到了更好的分類效果。
為了探究生成的未標記樣本和標記樣本的比例對實驗結果的影響,本文對比了模型在3 種數(shù)據(jù)集Cora、Citeseer、Pubmed 上,不同生成比例下的性能,如表6~表8 所示。
表6 SemiGATDS 在Cora 數(shù)據(jù)集上不同生成比例(未標記樣本∶標記樣本)的分類準確率Table 6 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Cora dataset
表7 SemiGATDS 在Citeseer數(shù)據(jù)集上不同生成比例(未標記樣本∶標記樣本)的分類準確率Table 7 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Citeseer dataset
表8 SemiGATDS 在Pubmed 數(shù)據(jù)集上不同生成比例(未標記樣本∶標記樣本)的分類準確率Table 8 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Pubmed dataset
從表中結果可以得出如下結論:在Citeseer、Pubmed 兩個數(shù)據(jù)集上,當生成的未標記樣本和標記樣本比例為1∶1 時,模型的效果更好。在Cora 數(shù)據(jù)集上,當生成的未標記樣本和標記樣本比例為1∶2 時模型的效果更好,但生成的未標記樣本和標記樣本比例為1∶1 和比例為1∶2 的效果相差不大,因此最終選取1∶1 的比例作為所有實驗的基準。
現(xiàn)有基于GAN 的圖半監(jiān)督學習算法能有效提升半監(jiān)督學習的分類性能,但標記樣本過少仍是其面臨的主要困難。針對這個問題,本文提出了一種基于GAN 的圖半監(jiān)督學習框架SemiGATDS,它通過生成器、分類器以及判別器之間的對抗訓練,同時生成服從真實樣本分布的標記樣本和與真實樣本分布不同的未標記樣本,當模型達到穩(wěn)態(tài)時,生成的標記樣本可以擴充標記樣本訓練集,生成的未標記樣本可以減少密度間隙中鄰近節(jié)點的影響,使決策邊界更清晰,從而提高圖半監(jiān)督分類的效果。在多個數(shù)據(jù)集上本文提出的SemiGATDS 均優(yōu)于現(xiàn)有的方法,進一步討論了不同數(shù)量的標記樣本和不同生成樣本比例對SemiGATDS 性能的影響,實驗結果驗證了該方法的有效性。