鄧維斌,王智瑩,高榮壕,王國胤,胡 峰
(重慶郵電大學 計算智能重慶市重點實驗室,重慶 400065)
文本分類目前廣泛應用于情感分析、語言推理、主題分類、垃圾郵件檢測、新聞過濾等領域,已經(jīng)成為自然語言處理中一項重要任務。在傳統(tǒng)單標簽文本分類中,每個樣本只對應一個標簽,且各標簽之間相互獨立。大數(shù)據(jù)時代文本信息日益豐富使得類別劃分越來越詳細,一個樣本往往與多個標簽相關,同時,標簽之間常存在聯(lián)系,對于這種分類任務稱為多標簽文本分類[1]。
多標簽文本分類已廣泛應用于許多真實場景,如涉及多個學科的論文、討論多個主題的博客、包含多種情感的評論等。與二分類相比,多標簽文本分類需要對文本有更加深入地理解,以提供更全面、準確的標簽預測[2]。對于多標簽文本分類問題,人們首先考慮到的是將多標簽文本分類任務轉化成技術已經(jīng)逐漸成熟的單標簽分類任務。然而,這種方法認為標簽之間相互獨立,忽略了標簽之間的關聯(lián)性。隨著深度學習的發(fā)展,學者們提出了許多基于深度學習的多標簽分類模型,并取得了較好的成效。其中,基于序列到序列(Seq2seq)的模型在多標簽分類領域得到廣泛應用,該模型使用編碼器提取文本信息,通過解碼器按順序預測標簽,顯著提高了多標簽分類的性能[3]。隨著注意力機制出現(xiàn),研究者將注意力機制引入Seq2seq模型可以有效地捕獲文本的重要信息,但是,傳統(tǒng)注意力機制往往只重視文本語義信息提取而忽略標簽語義,導致模型不能充分捕獲標簽語義和標簽間的關聯(lián)信息。
目前,存在許多模型用來處理多標簽分類問題,但仍然有不足之處。其中,如何有效學習標簽之間的依賴關系,并將標簽間關聯(lián)信息與文本信息進行更加自適應地交互成為一個關鍵問題。針對上述問題,本文提出了一種融合注意力與CorNet的多標簽文本分類算法。為了捕獲標簽間的依賴關系,利用基于標簽特征矩陣的圖注意力網(wǎng)絡學習標簽之間的相關性,并且在標簽預測層之后添加CorNet模塊捕捉標簽間關聯(lián)關系以增強標簽預測。設計“文本-標簽”注意力機制,將文本特征和標簽間關系特征進行點乘運算獲得每個單詞對當前標簽的權重,并將權重和文本表示結合得到最終的文檔表示,實現(xiàn)標簽間關聯(lián)信息與文本特征信息自適應交互。
多標簽文本分類算法大致分為兩類:基于傳統(tǒng)機器學習的算法和基于深度學習的算法。
傳統(tǒng)機器學習算法主要包括問題轉換方法和算法自適應方法兩大類[4]。問題轉換方法是將多標簽文本分類任務轉換為多個獨立的二進制分類問題。Zhang等提出二元相關(binary relevance,BR)方法就是問題轉換方法[5],該方法通過給每個標簽建立一個單獨的分類器實現(xiàn)多標簽文本分類,但BR方法忽略了標簽之間的相關性使得模型性能較低。為了捕捉標簽相關性,Read等提出了分類器鏈(classifier chain,CC)方法[6],將多個二進制分類器連接在一起,每個分類器使用來自前一個分類器的預測作為輸入。這種方法的缺點是不同的標簽順序可能會產生不同結果,連接過程也意味著CC方法無法并行化,因此,在處理大型數(shù)據(jù)集時會產生較高的計算成本。算法自適應方法是對傳統(tǒng)的單標簽分類算法進行改進來解決多標簽分類問題。代表性算法有Clare等提出的ML-DT(multi-Label decision tree)方法[7],其基本思想是通過使用熵的信息增益遞歸地構建決策樹來處理多標簽問題。Elisseeff等提出排名支持向量機(ranking support vector machine,rank-SVM)方法[8],基于與SVM的特性構造類似于學習系統(tǒng)的支持向量機來處理多標簽問題,但這種系統(tǒng)的表達能力很弱。Younes等將K最近鄰(KNN)算法應用到多標簽分類問題[9],并且考慮了標簽之間的依賴關系。然而這些傳統(tǒng)的機器學習方法在處理多標簽問題時不能充分挖掘文本語義信息,大大降低了多標簽分類的精度。
深度學習的快速發(fā)展使基于深度學習的模型效果有了很大的提升。深度學習方法廣泛應用于自然語言處理領域。Jacovi等提出CNN模型處理文本分類對文本進行最大程度的特征提取[10],從而提高了文本分類的效果。Liu等提出XML-CNN模型使用CNN設計了一個動態(tài)池處理文本分類,在池化層和輸出層之間加了一個隱藏層來降低標簽維度以減少計算量,并且改進了損失函數(shù),采用二元交叉熵損失函數(shù),使得文本分類效果得到明顯的提升[11]。雖然基于CNN 的算法在多標簽分類任務中取得了不錯的研究成果,但這類算法僅僅從局部提取文本語義信息,缺乏對全局信息的考量,沒有考慮標簽之間的關聯(lián)性。宋攀等提出利用神經(jīng)網(wǎng)絡構造矩陣刻畫標簽之間的依賴關系,同時可以解決標簽缺失問題[12]。Chen等提出的CNN-RNN[13]和Yang等提出的序列生成模型(SGM)[14]通過使用編碼器和解碼器分別對文本進行編碼及生成可能的編碼序列,但這類方法過于依賴標簽的順序,標簽順序不同時可能會產生不同的結果。You等提出AttentionXML使用自注意力機制來捕獲與每個標簽最相關的文本但忽略了標簽信息[15]。Xiao等提出的LSAN模型提出標簽注意力機制學習特定于標簽的文本表示,將標簽語義信息引入到模型中[16]。Yao等提出用圖卷積網(wǎng)絡(GCN)對文本進行分類,基于單詞共現(xiàn)和文檔-單詞關系為語料庫構建一個異構圖,并使用圖卷積神經(jīng)網(wǎng)絡聯(lián)合學習單詞和文檔嵌入[17]。盡管圖卷積神經(jīng)網(wǎng)絡已取得了較好的效果,但GCN仍然缺少重要的結構特征,無法更好地捕捉節(jié)點之間的相關性或依賴性。
為了進一步提升多標簽文本分類模型的性能,提出了一種融合注意力與CorNet的多標簽文本分類模型MLACN,通過圖注意力網(wǎng)絡和CorNet模塊充分捕獲標簽間的語義依賴,同時利用“文本-標簽”注意力機制,將標簽之間的語義關系與文本上下文語義信息進行交互,獲取基于標簽語義信息的文本特征表示。模型如圖1所示。
圖1 模型框架
為了更好地捕捉文本雙向語義關系,采用Bi-LSTM[18]從前后兩個方向分別提取文本上下文語義信息,并計算每個單詞的隱表示
(1)
將文本中每個單詞的隱表示串聯(lián)得到整體文本表示
(2)
(3)
其中:αij是hi的歸一化系數(shù);wj是注意參數(shù),每個標簽的wj不同;M(s)是多標簽注意力機制下特定于標簽的文本表示。
圖注意力網(wǎng)絡[20]將標簽數(shù)據(jù)的節(jié)點特征和鄰接矩陣作為輸入?;跇撕灅嬙爨徑泳仃?模型通過學習鄰接矩陣確定圖,從而學習標簽的相關性。圖注意力網(wǎng)絡通過將標簽間的關聯(lián)關系構建成加權圖,以便鄰接矩陣和注意權重一起表示標簽的相關性。
2.4.1 構建鄰接矩陣 通過計算標簽的成對共現(xiàn)來構造鄰接矩陣。頻率向量是向量F∈Rl,Fi是整個訓練集中標簽i的頻率,通過頻率向量對共現(xiàn)矩陣L進行歸一化。
Ladj=L/F
(4)
其中:Ladj∈Rl×l是鄰接矩陣;F∈Rl是單個標簽的頻率向量。
圖2 GAT模型
eij=a(WHi,WHj)
(5)
其中:W是可訓練參數(shù);a是前饋神經(jīng)網(wǎng)絡的可訓練參數(shù);eij表示節(jié)點j對于節(jié)點i的重要性,并且節(jié)點i必須是節(jié)點j的一階鄰居。注意力系數(shù)計算公式為
(6)
其中:LeakyReLU為非線性激活函數(shù);αij為標簽j相對于標簽i的歸一化注意系數(shù);k∈Ni表示節(jié)點i的所有一階鄰域節(jié)點。
根據(jù)式(6)的注意力系數(shù),對特征進行加權求和
(7)
GAT中還加入了多頭注意力機制,將經(jīng)過K頭注意力機制計算后的特征向量進行拼接,對應的輸出特征向量表達為
(8)
(9)
經(jīng)過GAT計算后的向量記作Hgat∈Rc×d,其中:c表示標簽數(shù)量;d表示標簽的特征尺寸。
文本中每個單詞對于不同的標簽起到的作用是不同的。為了強化標簽之間的語義聯(lián)系,將標簽語義信息與文本上下文語義信息進行交互,獲得基于標簽語義的文本特征表示,設計“文本-標簽”注意力計算每個單詞的重要度,通過將文本特征H與標簽特征向量Hgat進行點乘計算獲得文本和標簽之間的匹配得分A
(10)
文本內容對不同標簽的重要程度是不同的,為了建立文本和標簽之間的關系,將上一層得到的A轉置乘以文本的隱表示,得到標簽對應的文本表示
(11)
M(s)和M(l)都是標簽對應的文檔表示,但是兩者的側重點不同。前者側重于文檔內容,后者側重于標簽內容。為了充分利用這兩個部分的優(yōu)勢,使用自適應融合機制,以自適應地從中提取信息,并得到最終的文檔表示。
將M(s)和M(l)作為全連接層的輸入,通過全連接層獲得兩個權重向量β,γ來確定上述兩個注意力機制的重要性。
β=Sigmoid(M(s)W1)
γ=Sigmoid(M(l)W2)
(12)
其中:W1,W2∈R2k是可訓練參數(shù);βj和γj分別表示多標簽注意力機制和“文本-標簽”注意力機制在對第j個標簽構建最終的文本表示時的重要程度
βj+γj=1
(13)
然后,根據(jù)融合權重獲得第j個標簽的最終文本表示為
(14)
所有標簽的最終文檔表示為M。
本文使用多層感知機實現(xiàn)標簽預測,預測第i個標簽出現(xiàn)的概率通過式(15)獲得
yx=W4f(W3MT)
(15)
其中:W3、W4是參數(shù)矩陣;函數(shù)f為RELU激活函數(shù)。
圖3 CorNet模型
(16)
F(x)=W6δ(W5σ(x)+b1)+b2
(17)
其中:W5、W6是權重矩陣;b1、b2是偏置;σ、δ分別是Sigmoid和ELU激活函數(shù)。
MLACN 使用二元交叉熵損失(binary cross entropy loss)[23]作為損失函數(shù)計算損失值如下
(18)
實驗采用AAPD、RCV1-V2和Reuters-21578多標簽分類數(shù)據(jù)集。
AAPD[14]:該數(shù)據(jù)集為北京大學大數(shù)據(jù)研究院提供的公開英文數(shù)據(jù)集。數(shù)據(jù)集主要包括從網(wǎng)站上收集的55 840篇計算機科學領域論文摘要與相對應的主題。一篇論文摘要可能包含多個主題,總計54個主題詞。
RCV1-V2[24]:該數(shù)據(jù)集是由 Lewis 等提供的公開英文數(shù)據(jù)集,由路透社有限公司為研究人員提供的800 000多條人工分類的新聞通訊報道組成。每篇新聞報道包含多個主題,總計103個主題。
Reuters-21578[25]:該數(shù)據(jù)集中的文件是1987年從路透社收集的。這曾經(jīng)是自1996年以來從事文本分類研究人員的熱門數(shù)據(jù)集。根據(jù)路透社22173預覽版改編,現(xiàn)在包含21 578份文檔。
模型評價指標采用漢明損失(Hamming loss,記為HL)[26],精確率(precision,記為P),召回率(recall,記為R)和Micro-F1(記為F1)[27]。在這些指標中,Hamming Loss反映了分類錯誤的標簽數(shù)目,該指標值越小,則分類性能越好;精確率用來統(tǒng)計預測標簽集中預測正確的標簽所占比例,值越大說明分類性能越好;召回率表示樣本真實標簽集中被預測到的標簽的比例;Micro-F1值表示精確率和召回率的加權平均,該指標值越大,則分類性能越好。
為了充分驗證提出算法的有效性,選擇以下8種算法作為對比算法。
BR[5]:該算法提出將多標簽分類任務轉換為多個二進制分類任務。
CC[6]:基于一系列二進制分類任務來解決多標簽分類任務。
CNN[10]:主要利用卷積神經(jīng)網(wǎng)絡來學習密集的特征矩陣以捕獲文本局部語義信息。
CNN-RNN[11]:使用CNN和RNN獲得局部和全局語義,并對標簽之間的關系進行建模。
SGM[14]:一種將多標簽分類任務視為序列生成任務的模型,并將Seq2seq用作多類分類器。
LSAN[16]:利用標簽注意力機制建立特定于標簽的文本信息,同時使用自適應融合機制將標簽信息與文本信息融合。
AttentionXML[15]:利用多標簽注意力機制捕獲每個標簽最相關的文本。
ML-Reasoner[28]:該模型使用二元分類器預測標簽,同時提出一種迭代推理機制學習標簽之間的信息來避免過度依賴標簽順序。
使用Glove[29]預訓練詞向量對每個數(shù)據(jù)集的文本和標簽進行初始化,詞嵌入維度k=300。批處理大小為64,整個模型使用Adam[30]進行訓練,初始學習率為0.001,設置Dropout為0.5來防止過擬合。為了避免梯度爆炸,將模型最大梯度設置為5.0。在AAPD和Reuters-21578兩個數(shù)據(jù)集上設置GAT層數(shù)為2,圖注意力頭數(shù)為4;在RCV1-V2數(shù)據(jù)集上設置GAT層數(shù)為3,圖注意力頭數(shù)為4。CorNet層數(shù)設置為2。
MLACN模型在3個數(shù)據(jù)集上和其他基準算法評價指標得分情況見表1~表3,最優(yōu)結果用粗體表示。其中HL表示漢明損失,P和R分別表示 precision和recall,(-)表示值越低模型效果越好,(+)表示值越高模型效果越好。標有*的模型表示其結果為復現(xiàn)后的結果,未標記的模型直接引用論文的結果。
表1 在數(shù)據(jù)集AAPD上的對比結果
表2 在數(shù)據(jù)集RCV1-V2上的對比結果
表3 在數(shù)據(jù)集Reuters-21578上的對比結果
從實驗結果可以看出,本文提出的模型在更具挑戰(zhàn)性的AAPD數(shù)據(jù)集上的性能顯著優(yōu)于所有基線模型。特別是在Hamming Loss以及Micro-F1兩個指標上取得了最好的性能。MLACN模型與最常見的基準模型BR比較可以減少30.7%的Hamming Loss,提升13.9%的Micro-F1值。提出的模型性能遠超過CNN、CNN-RNN這些傳統(tǒng)的深度學習模型。同時,MLACN模型在4個性能指標上都超過了LSAN模型。與最近的ML-Reasoner模型相比減少11.7%的損失,同時取得了最好的F1值。隨著數(shù)據(jù)集大小的增加,在RCV1-V2數(shù)據(jù)集上不同模型之間的性能差異會減小,然而與其他基本模型相比,MLACN模型的性能仍然有明顯改善。 在Reuters-21578數(shù)據(jù)集上的表現(xiàn)與RCV1-V2數(shù)據(jù)集類似,提出的模型在Hamming Loss和F1評價指標上優(yōu)于其他基準模型。MLACN模型在Micro-F1指標上獲得了最好的性能,同時在其他指標的性能上均與最先進的模型性能相近。 這些實驗結果進一步驗證了MLACN模型在數(shù)據(jù)集上表現(xiàn)的優(yōu)越性。
基于深度學習的模型在大多數(shù)指標上都優(yōu)于傳統(tǒng)機器學習算法,這是因為基于深度學習的模型能充分利用訓練集捕獲更深層次的語義信息,從而更好地處理復雜數(shù)據(jù)。傳統(tǒng)CNN方法在精確率上具有一定的競爭力,在AAPD和RCV1-V2數(shù)據(jù)集上領先于目前所有的基線模型,是由于傳統(tǒng)的CNN模型非常適合提取局部特征。 CNN模型中的最大池化層會放大局部特征,使得基于CNN的分類通常依賴于明顯的特征。當正樣本多于負樣本時,CNN更傾向于生成有利于正樣本的特征,導致CNN分類中預測結果偏向正樣本,由召回率較低可以看出這個特點。CNN-RNN模型使用CNN和RNN獲得局部和全局語義,并對標簽之間的關系進行建模,在數(shù)據(jù)集上的效果與CNN模型相比有所提升。 LSAN和MLACN模型相比于其他模型較好,原因在于其他模型均沒有單獨的將文本標注的標簽信息考慮進去,盡管SGM與AttentionXML試圖建立文本與標簽之間的聯(lián)系,但僅僅局限于對文本內容的訓練與學習,會降低尾部標簽的預測能力。MLACN模型相比于LSAN有著進一步的提升,是因為一方面,通過Bi-LSTM和多標簽注意力機制對文本特征信息進行提取;另一方面,標簽之間的聯(lián)系不再局限于特定文本的語義聯(lián)系,而是通過多層 GAT 和CorNet充分挖掘全局標簽之間的聯(lián)系以及關聯(lián)程度。從整體上看,信息的融合與標簽的關聯(lián)有著更為緊密的聯(lián)系,MLACN模型有效提取文本特征信息的同時,也能學習標簽之間的聯(lián)系,進一步體現(xiàn)了模型的優(yōu)越性。
為了驗證帶有不同層數(shù)的GAT對模型性能的影響,在AAPD、RCV1-V2和Reuters-21578數(shù)據(jù)集上進行實驗,結果分別如圖4~圖6所示。實驗結果表明,在AAPD和Reuters-21578兩個數(shù)據(jù)集上,兩層GAT的效果最好,并且多標簽分類模型的性能隨著GAT層數(shù)的增長而降低。在RCV1-V2數(shù)據(jù)集上,當GAT層數(shù)為3時模型分類效果最好,且有隨著層數(shù)加深逐漸降低的趨勢??赡艿脑蚴?,標簽關系圖節(jié)點的周圍前3層節(jié)點的信息可能對最終的分類做更多的貢獻,而隨著層數(shù)的增加,其外圍的節(jié)點信息可能會產生噪聲,干擾模型的分類效果。RCV1-V2比其他兩個數(shù)據(jù)集更大,因此需要更深層次的GAT提取標簽之間的依賴關系。
圖4 不同GAT層數(shù)的MLACN模型在AAPD數(shù)據(jù)集上的對比結果
圖5 不同GAT層數(shù)的MLACN模型在RCV1-V2數(shù)據(jù)集上的對比結果
圖6 不同GAT層數(shù)的MLACN模型在Reuters-21578數(shù)據(jù)集上的對比結果
為了進一步驗證模型各組件的有效性,本文在3個數(shù)據(jù)集上進行了3組消融實驗,實驗結果如表4~表6所示。
表4 AAPD消融實驗結果
表5 RCV1-V2消融實驗結果
表6 Reuters-21578消融實驗結果
1)Without CorNet表示沒有使用CorNet模塊,僅通過GAT獲取標簽之間的依賴關系;
2)Without GAT表示沒有使用GAT模型,僅通過CorNet模塊增強標簽預測捕獲標簽之間的依賴關系;
3)Without SL-ATT表示沒有使用“文本-標簽”注意力機制,沒有將文本和標簽語義進行交互。
在3個數(shù)據(jù)集上的消融實驗結果顯示,Without CorNet模型和Without GAT模型與MLACN模型相比F1值均有所降低,表明CorNet模塊和GAT可以捕獲標簽的依賴關系, 提升模型分類效果。在AAPD和Reuters-21578數(shù)據(jù)集上,Without SL-ATT模型與MLACN模型相比F1值分別降低了1.4%和0.3%,表明“文本-標簽”注意力機制對提升模型性能是有利的,可以對文本和標簽語義進行交互,更好地提取文本和標簽語義。而在RCV1-V2數(shù)據(jù)集上,Without SL-ATT模型與MLACN模型相比結果相差不大,原因可能是RCV1-V2數(shù)據(jù)集較大,文本和標簽信息比較豐富,僅僅通過點乘運算的方法使得文本和標簽進行交互效果甚微,需要設計更深層次的交互機制進一步提取文本和標簽信息。
從整體消融實驗結果來看,MLACN模型能夠有效地融合各個組件的優(yōu)勢,提升模型整體效果。
本文提出了一種融合注意力與CorNet的多標簽文本分類模型MLACN。模型利用多層圖注意力網(wǎng)絡(GAT)通過標簽特征和標簽的鄰接矩陣構建標簽關聯(lián)圖,學習標簽之間的依賴關系,并設計“文本-標簽”注意力機制將標簽信息與文本上下文語義信息進行交互,獲得基于標簽語義信息的文本特征表示,并在標簽的預測層之后添加CorNet模塊學習標簽的相關性增強標簽預測。在3個標準多標簽文本分類的數(shù)據(jù)集上得到的實驗結果表明,所提出的方法其性能優(yōu)于當前先進的多標簽文本分類算法,驗證了MLACN模型的優(yōu)越性,同時,也驗證了引入圖注意力網(wǎng)絡和CorNet,建立具有文本語義聯(lián)系的標簽特征表示的有效性與合理性。
在接下來的工作中,將考慮如何處理大規(guī)模標簽數(shù)據(jù)集的多標簽分類問題,從更深層次挖掘語義聯(lián)系。同時,調整模型參數(shù)進一步優(yōu)化模型,降低訓練的時間復雜度,從而高效、準確地預測標簽。