王曉茹,張 珩
1.北京郵電大學(xué) 計(jì)算機(jī)學(xué)院,北京 100876
2.北京市網(wǎng)絡(luò)系統(tǒng)與網(wǎng)絡(luò)文化重點(diǎn)實(shí)驗(yàn)室,北京 100876
在過去的幾年,利用深度神經(jīng)網(wǎng)絡(luò),在很多圖像識別數(shù)據(jù)集上的準(zhǔn)確率已經(jīng)取得了很大地提高。這些模型[1-4]往往都基于卷積神經(jīng)網(wǎng)絡(luò),而且需要大量有標(biāo)簽的數(shù)據(jù)來訓(xùn)練。為了提高識別的準(zhǔn)確率,網(wǎng)絡(luò)向著更深和更復(fù)雜的方向發(fā)展,這無疑增加了網(wǎng)絡(luò)參數(shù),同時(shí)需要更多有標(biāo)簽的數(shù)據(jù)來訓(xùn)練。另一方面,基于傳統(tǒng)深度神經(jīng)網(wǎng)絡(luò)的圖像識別算法訓(xùn)練的模型往往只能識別訓(xùn)練數(shù)據(jù)中出現(xiàn)的類別,進(jìn)一步限制了算法的應(yīng)用和發(fā)展。不同于機(jī)器,人類通過幾次甚至一次觀察就能夠識別圖像的特征,當(dāng)再次看到相同類別的圖像時(shí)就能夠準(zhǔn)確地識別出來。希望機(jī)器也能夠擁有這樣的能力,所以在機(jī)器學(xué)習(xí)的領(lǐng)域中,一個(gè)新的方向逐漸受到大家的關(guān)注——小樣本學(xué)習(xí)[5-6]。
小樣本學(xué)習(xí)的目的是通過少數(shù)幾張甚至一張有標(biāo)簽圖像學(xué)習(xí)圖像的類別信息??梢酝ㄟ^對傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)進(jìn)行精細(xì)的調(diào)參來解決小樣本學(xué)習(xí)問題。但是由于有標(biāo)簽的訓(xùn)練數(shù)據(jù)不足,很容易導(dǎo)致網(wǎng)絡(luò)出現(xiàn)過擬合的現(xiàn)象。盡管能夠通過數(shù)據(jù)增強(qiáng)和正則化的技巧減輕過擬合現(xiàn)象,但是并不能從根本上解決這個(gè)問題。
通過元學(xué)習(xí)[7]能夠很好地解決小樣本學(xué)習(xí)問題,目前已經(jīng)有很多基于元學(xué)習(xí)思想的模型被提出。在元學(xué)習(xí)的訓(xùn)練階段,會將訓(xùn)練集分為一個(gè)個(gè)元任務(wù),然后用這一個(gè)個(gè)元任務(wù)去訓(xùn)練網(wǎng)絡(luò),即是指導(dǎo)網(wǎng)絡(luò)如何去解決一個(gè)個(gè)元任務(wù),而不是訓(xùn)練網(wǎng)絡(luò)識別所有類別的圖像。這樣,在測試的時(shí)候,即使測試數(shù)據(jù)集中有新的類別的圖像,網(wǎng)絡(luò)仍然有足夠的泛化能力識別出圖像的類別。基于元學(xué)習(xí)思想的小樣本學(xué)習(xí)算法主要包括基于度量的算法[8-14]、基于數(shù)據(jù)增強(qiáng)的算法[15-17]和基于模型優(yōu)化的算法[18-21]。其中基于度量的算法被認(rèn)為是一種最簡單而有效的方法。
基于度量的小樣本學(xué)習(xí)算法一般包含兩步:(1)通過一個(gè)特征提取器提取支撐集圖像和查詢集圖像的特征。(2)在一個(gè)統(tǒng)一的空間中比較特征的距離或者相似度,從而得到圖像所屬的類別。盡管有很多基于度量的小樣本學(xué)習(xí)算法被提出來,但是這些算法仍然很難達(dá)到很高的準(zhǔn)確率。認(rèn)為現(xiàn)有的基于度量的小樣本學(xué)習(xí)算法主要有兩個(gè)方面的不足:(1)特征提取網(wǎng)絡(luò)不能關(guān)注那些對分類起決定性作用的特征。(2)簡單的將圖像映射到同一個(gè)特征空間進(jìn)行比較,不能充分利用支撐集中不同類別圖像之間特征的差異。
基于上面的分析,主要關(guān)注兩個(gè)方面的問題:(1)如何更好地提取特征來指導(dǎo)后面的分類?(2)如何充分利用提取的特征以及特征之間的關(guān)系?因此,提出了基于注意力機(jī)制和圖卷積網(wǎng)絡(luò)的小樣本目標(biāo)識別算法。這篇文章的主要貢獻(xiàn)有:
(1)提出了添加了空間和通道注意力機(jī)制的特征提取網(wǎng)絡(luò),使得特征提取網(wǎng)絡(luò)能夠更好提取對分類有幫助的特征。
(2)提出了基于圖卷積網(wǎng)絡(luò)的關(guān)系網(wǎng)絡(luò),能夠在比較特征相似度的同時(shí)利用不同類別圖像特征之間的關(guān)系,從而提高分類的準(zhǔn)確率。
(3)在Omniglot數(shù)據(jù)集和miniImageNet數(shù)據(jù)集上進(jìn)行了大量的實(shí)驗(yàn),相較于其他的基于度量的小樣本學(xué)習(xí)算法,本文模型達(dá)到了更高的準(zhǔn)確率。
在這部分,將主要介紹基于度量的小樣本學(xué)習(xí)算法、注意力機(jī)制以及圖卷積網(wǎng)絡(luò)。
小樣本或者單樣本分類任務(wù)是一個(gè)很有應(yīng)用前景的研究方向。傳統(tǒng)的深度學(xué)習(xí)方法并不能很好的解決小樣本學(xué)習(xí)問題。最近,使用深度神經(jīng)網(wǎng)絡(luò)的方法已經(jīng)取得很好的效果。但是與傳統(tǒng)的分類任務(wù)相比,小樣本分類任務(wù)在分類的準(zhǔn)確率上還遠(yuǎn)遠(yuǎn)達(dá)不到同樣的水平。
已經(jīng)有很多研究表明使用元學(xué)習(xí)的方法能夠很好的解決小樣本分類任務(wù),其中基于度量的小樣本學(xué)習(xí)方法是一種簡單而又有效的方式。雙生網(wǎng)絡(luò)[8(]Siamese Network)是一種通過共享網(wǎng)絡(luò)權(quán)值實(shí)現(xiàn)的網(wǎng)絡(luò)。雙生網(wǎng)絡(luò)有兩個(gè)輸入,通過兩個(gè)共享權(quán)重的神經(jīng)網(wǎng)絡(luò)將輸入映射到新的空間,從而可以在這個(gè)新的空間計(jì)算兩個(gè)樣本的相似程度,最終達(dá)到分類的目的。匹配網(wǎng)絡(luò)[9](Matching Network)先使用一個(gè)卷積神經(jīng)網(wǎng)絡(luò)獲得支持集和測試集的淺層表示,然后將它們放入一個(gè)雙向LSTM 網(wǎng)絡(luò),最后通過計(jì)算輸出特征的余弦相似度來表示查詢集圖像與支撐集圖像的相似度。不需要對一個(gè)訓(xùn)練好的匹配網(wǎng)絡(luò)做任何改變,網(wǎng)絡(luò)就能夠識別訓(xùn)練過程中沒有遇見的類別。原型網(wǎng)絡(luò)[11(]Prototypical Network)是將輸入圖像映射到一個(gè)潛在空間。其中,一個(gè)類別的原形是對支撐集中的所有相同類別圖像的向量化樣例數(shù)據(jù)取均值得到的。然后再通過計(jì)算查詢集圖像的向量化值與類別原型之間的歐氏距離就能得到查詢集圖像的類別。即原型網(wǎng)絡(luò)認(rèn)為在映射后的空間中距離越近的樣例屬于同一類別的可能性越大。前面的提到的雙生網(wǎng)絡(luò)和原型網(wǎng)絡(luò)都是通過一個(gè)神經(jīng)網(wǎng)絡(luò)得到樣例的向量化表示,然后計(jì)算向量化表示之間的相似度或者距離來判斷樣例是不是同一類別。而關(guān)系網(wǎng)絡(luò)[12(]Relation Network)則是通過一個(gè)神經(jīng)網(wǎng)絡(luò)來計(jì)算不同樣例之間的距離。
注意力機(jī)制[22]最初在機(jī)器翻譯中被使用。隨后,注意力機(jī)制開始在計(jì)算機(jī)視覺任務(wù)中被使用。注意力機(jī)制和人類的視覺機(jī)制相似,由于人眼看到的圖像往往包含了大量的信息,但是大腦處理圖像的時(shí)候往往只會關(guān)注某些重要的部分,忽略那些不重要的信息,這樣能夠加快大腦處理圖像信息的速度。所以在計(jì)算機(jī)視覺中使用注意力機(jī)制可以讓神經(jīng)網(wǎng)絡(luò)關(guān)注與任務(wù)相關(guān)的信息。注意力機(jī)制可以簡單的理解為對某個(gè)時(shí)刻的輸出y,在輸入x上各部分施加不同的注意力的一種機(jī)制,而這里注意力就是權(quán)重。注意力機(jī)制可以分為軟注意力機(jī)制和強(qiáng)注意力機(jī)制[23]。軟注意力是可學(xué)習(xí)的,通常可以嵌入模型中直接訓(xùn)練,而強(qiáng)注意力是一個(gè)隨機(jī)預(yù)測的過程。目前,神經(jīng)網(wǎng)絡(luò)中使用較多的是軟注意力機(jī)制。
在過去的幾年中,神經(jīng)網(wǎng)絡(luò)的興起與應(yīng)用成功推動了模式識別和數(shù)據(jù)挖掘的研究。很多曾經(jīng)機(jī)器并不能很好的解決的問題,現(xiàn)在已經(jīng)能通過各種各樣的深度學(xué)習(xí)模型解決了。但是傳統(tǒng)的深度學(xué)習(xí)方法只能夠被應(yīng)用到提取歐式空間數(shù)據(jù)的特征上,許多從非歐式空間產(chǎn)生的數(shù)據(jù),傳統(tǒng)深度學(xué)習(xí)在上面的表現(xiàn)卻仍不盡如人意。因而人們開始設(shè)計(jì)能夠處理非歐式結(jié)構(gòu)數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò),即圖神經(jīng)網(wǎng)絡(luò)。在傳統(tǒng)深度學(xué)習(xí)中,數(shù)據(jù)樣本之間往往被認(rèn)為是獨(dú)立的,但是在圖神經(jīng)網(wǎng)絡(luò)中,每個(gè)樣本結(jié)點(diǎn)都會通過邊與其他數(shù)據(jù)樣本建立聯(lián)系,這些信息能夠用來捕獲不同實(shí)例之間的相互依賴關(guān)系。圖神經(jīng)網(wǎng)絡(luò)包含了很多類別[24],在這篇論文中,使用了圖卷積網(wǎng)絡(luò)[25-26],圖卷積網(wǎng)絡(luò)是對圖結(jié)構(gòu)的數(shù)據(jù)進(jìn)行操作的卷積神經(jīng)網(wǎng)絡(luò)。
在這部分,首先將定義小樣本學(xué)習(xí)中的專業(yè)術(shù)語和相關(guān)的符號,然后介紹使用的數(shù)據(jù)集,最后將提出基于注意力機(jī)制和圖卷積網(wǎng)絡(luò)的小樣本目標(biāo)識別算法,并且介紹模型設(shè)計(jì)的細(xì)節(jié)。
對于人來說,初識一個(gè)新的物品,人類可以通過探索很快地了解并熟悉它,而這種學(xué)習(xí)能力,是目前機(jī)器所沒有的。如果機(jī)器也能擁有這種學(xué)習(xí)能力,面對樣本量較少的問題時(shí),便可以快速地學(xué)習(xí),這便是元學(xué)習(xí)。已經(jīng)有很多研究發(fā)現(xiàn),通過元學(xué)習(xí)可以很好的解決小樣本分類問題。
基于元學(xué)習(xí)的小樣本分類方法將分類任務(wù)分為一個(gè)一個(gè)的元任務(wù)。通常,元學(xué)習(xí)將訓(xùn)練集分為訓(xùn)練任務(wù)集和測試任務(wù)集。在訓(xùn)練的過程中,隨機(jī)的從訓(xùn)練集中抽取C×K個(gè)樣本作為支撐集(support set),其中C表示類別數(shù),K表示每一類的圖像數(shù)。然后再從訓(xùn)練集中剩下的類別為C中某一類的圖像中抽取一定數(shù)量的圖像作為查詢集(query set),即那個(gè)由查詢集和支撐集圖像構(gòu)成的一組數(shù)據(jù)成為一個(gè)episode。小樣本分類任務(wù)的目是要得到查詢集中的圖像分別屬于C類中的哪一類。當(dāng)每次抽取的類別數(shù)為C,每類圖像的數(shù)量為K時(shí),把這樣的任務(wù)稱為C-way K-shot問題。式(1)和式(2)給出了支撐集和查詢集的定義。
其中S表示支撐集,xi和yi表示支撐集中的第i張圖像以及其對應(yīng)的標(biāo)簽。
其中Q表示查詢集,N表示查詢集中的圖像數(shù)量,xi和yi分別表示查詢集中的第i張圖像以及其對應(yīng)的標(biāo)簽。特別的,當(dāng)支撐集中每類圖像的數(shù)量K=1 時(shí),稱這類任務(wù)為單樣本學(xué)習(xí);當(dāng)K>1 時(shí),稱為小樣本學(xué)習(xí)。
本文提出了一種基于注意力機(jī)制和圖卷積網(wǎng)絡(luò)的端到端的模型來解決小樣本分類問題。模型的框架如圖1所示。
從圖1 可以看到,本文模型包含兩個(gè)網(wǎng)絡(luò):基于注意力機(jī)制的特征提取網(wǎng)絡(luò)(FN)和基于圖卷積網(wǎng)絡(luò)的關(guān)系網(wǎng)絡(luò)(RN)。特征提取網(wǎng)絡(luò)用于提取輸入圖像的高維特征表示,而關(guān)系網(wǎng)絡(luò)根則根據(jù)兩張圖像的特征表示判斷兩張圖像是不是同一類,從而得到查詢集圖像的所屬的類別。
圖1 網(wǎng)絡(luò)結(jié)構(gòu)Fig.1 Network architecture
以5-way1-shot問題為例,從數(shù)據(jù)集中隨機(jī)抽取5 張類別不同的圖像{x1,x2,x3,x4,x5}組成支撐集,然后從數(shù)據(jù)集剩余的類別與支撐集類別相同的圖像中隨機(jī)抽取一張圖像作為查詢集。這里目標(biāo)是判斷圖像xˉ和支撐集{x1,x2,x3,x4,x5}中哪一張圖像屬于同一類別。
首先,將所有圖像都輸入到特征提取網(wǎng)絡(luò),得到每張圖像的特征然后,將支撐集圖像特征和查詢集圖像特征拼接在一起,得到其中C(·,·) 是特征的拼接操作。將拼接后的特征送入關(guān)系網(wǎng)絡(luò),最后關(guān)系網(wǎng)絡(luò)的輸出一個(gè)5 維的向量,表示兩個(gè)特征對應(yīng)圖像屬于同一類別的概率。
2.3.1 特征提取網(wǎng)絡(luò)
特征提取網(wǎng)絡(luò)使用了一個(gè)包含四層卷積的CNN網(wǎng)絡(luò)。為了更好地提取任務(wù)相關(guān)的特征,在特征提取網(wǎng)絡(luò)中添加了注意力模塊。比較了一些注意力模塊,如Nonlocal[27]等,但是Non-local 的計(jì)算量較大,而且特征提取網(wǎng)絡(luò)的深度很小,所以選擇了即插即用的卷積注意力模塊[28(]Convolutional Block Attention Module,CBAM)。CBAM 包含了兩個(gè)維度的注意力——空間注意力和通道注意力。
特征的每一個(gè)通道都表示一個(gè)專門的檢測器,可以認(rèn)為每一個(gè)通道代表著一種不同的特征。因此使用通道注意力可以使得網(wǎng)絡(luò)知道關(guān)注什么特征針對目前的任務(wù)是有意義的。如圖2(a)所示,先在每一個(gè)通道的特征圖上進(jìn)行全局平均池化和全局最大池化,得到兩個(gè)1×1×C的通道描述。然后將這兩個(gè)通道描述送入一個(gè)多層感知機(jī)中,這個(gè)多層感知機(jī)由兩個(gè)共享權(quán)重的全連接層組成,即將全局平均池化和全局最大池化的結(jié)果都通過這個(gè)全連接層得到相應(yīng)的輸出,這兩層全連接層的神經(jīng)元個(gè)數(shù)分別為和C,在實(shí)驗(yàn)中,取r=2。然后將輸出的特征經(jīng)過逐元素相加后經(jīng)過Sigmoid激活函數(shù)得到最終的通道注意力。最后,將得到的通道注意力與原來的特征圖相乘。
而空間注意力能夠關(guān)注哪里來的特征是有意義的。如圖2(b)所示,與通道注意力相似,在通道維度上進(jìn)行最大池化和平局池化,得到兩個(gè)H×W×1 的空間描述。然后,將這兩個(gè)特征在通道維度拼接,經(jīng)過一個(gè)卷積層后得到空間注意力。最后將經(jīng)過了通道注意力的特征圖與空間注意力相乘,最終得到了經(jīng)過調(diào)整的特征圖。
圖2 卷積注意力模塊Fig.2 Convolutional block attention module
2.3.2 基于圖卷積的關(guān)系網(wǎng)絡(luò)
關(guān)系網(wǎng)絡(luò)的目的是比較支撐集圖像特征和查詢集圖像特征,從而得到查詢集圖像屬于支撐集圖像類別中的哪一類?;趥鹘y(tǒng)神經(jīng)網(wǎng)絡(luò)的關(guān)系網(wǎng)絡(luò)僅僅將查詢集圖像特征和一類支撐集圖像特作為輸入,所以不能充分利用支撐集中不同類別圖像之間的差別和相似性作出更準(zhǔn)確的判斷。所以,針對小樣本學(xué)習(xí)的訓(xùn)練特點(diǎn),將不同的圖像特征抽象成圖卷積中的節(jié)點(diǎn)特征,同時(shí)將節(jié)點(diǎn)的距離作為邊的權(quán)重。這樣關(guān)系網(wǎng)絡(luò)就能夠利用圖卷積里面的消息傳遞的特性獲得整個(gè)支撐集的信息,從而作出更準(zhǔn)確的判斷。
在前面的特征提取網(wǎng)路中,得到了支撐集圖像特征和查詢集圖像特征,并把它們拼接在一起得到了在特征提取網(wǎng)絡(luò)中,在C-way Kshot,K>0 的時(shí)候,將同一類別圖像特征的平均值作為該類的特征,即拼接后特征先經(jīng)過一個(gè)全連接層,得到融合后的特征。然后,將這些特征作為圖的結(jié)點(diǎn),支撐集圖像的特征之間的距離作為邊的權(quán)重,組成一個(gè)完全圖。距離的計(jì)算方式如下:
這個(gè)完全圖經(jīng)過圖卷積網(wǎng)絡(luò)后,每個(gè)節(jié)點(diǎn)輸出一個(gè)特征。將輸出的特征經(jīng)過Softmax激活函數(shù)得到查詢集圖像屬于每一類的概率。
2.3.3 損失函數(shù)
在模型的訓(xùn)練過程中,使用了交叉熵?fù)p失函數(shù),即:
φ、θ分別為特征提取網(wǎng)絡(luò)和關(guān)系網(wǎng)絡(luò)的參數(shù),ri,j表示第i組數(shù)據(jù)中查詢集圖像與第j張支撐集圖像為同一類的概率。
在這部分,將介紹實(shí)驗(yàn)環(huán)境設(shè)置和使用的數(shù)據(jù)集,然后將介紹網(wǎng)絡(luò)的具體結(jié)構(gòu)。最后,進(jìn)行了大量的實(shí)驗(yàn),驗(yàn)證了本文模型在小樣本分類問題中能夠取得的效果。同時(shí),進(jìn)行了消融實(shí)驗(yàn)驗(yàn)證模型中相關(guān)模塊的有效性。
在Omniglot 數(shù)據(jù)集和miniImageNet 數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn)。Omniglot數(shù)據(jù)集包含來自50個(gè)不同字母表的1 623 個(gè)不同手寫字符,每個(gè)字符包含20 張28×28 圖片。使用了與文獻(xiàn)[9,11-12]類似的處理方式,將圖像分別旋轉(zhuǎn)90°、180°和270°作為新的類別。然后將1 200個(gè)字符以及它們旋轉(zhuǎn)之后的字符作為訓(xùn)練集,剩下的423個(gè)字符以及它們旋轉(zhuǎn)之后的字符作為測試集。與文獻(xiàn)[9-12,18]類似,在Omniglot 數(shù)據(jù)集上進(jìn)行了20-way1-shot和20-way5-shot的實(shí)驗(yàn)。
miniImageNet 數(shù)據(jù)集是ImageNet 數(shù)據(jù)集的子集。miniImageNet 數(shù)據(jù)集包含了100 類圖像,每一類由600張圖像組成,將圖像統(tǒng)一處理成84×84 的大小。和文獻(xiàn)[12]中一樣,將100 類中的64 類作為訓(xùn)練集,16 類作為驗(yàn)證集,20 類作為測試集。與文獻(xiàn)[9-12,18]類似,在miniImageNet 數(shù)據(jù)集上進(jìn)行了5-way1-shot和5-way5-shot的實(shí)驗(yàn)。兩個(gè)數(shù)據(jù)集的實(shí)驗(yàn)設(shè)置如表1所示。
表1 在Omniglot和miniImageNet數(shù)據(jù)集上的試驗(yàn)設(shè)置Table 1 Experiment settings on Omniglot and miniimageNet dataset
本文的模型一共包含了兩個(gè)部分:特征提取網(wǎng)絡(luò)和基于圖卷積的關(guān)系網(wǎng)絡(luò)。
特征提取網(wǎng)絡(luò)由4個(gè)卷積層組成,在特征提取網(wǎng)絡(luò)里面使用的通道注意力和空間注意力。特征提取網(wǎng)絡(luò)的具體結(jié)構(gòu)如圖3。在每一個(gè)卷積操作之后添加了一個(gè)注意力模塊。
圖3 特征提取網(wǎng)絡(luò)的結(jié)構(gòu)Fig.3 Architecture of feature extraction network
在關(guān)系網(wǎng)絡(luò)中,首先將查詢集圖像特征和支撐集圖像特征經(jīng)過一個(gè)全連接層進(jìn)行融合,然后將融合后的特征作為圖的結(jié)點(diǎn)特征,由這些結(jié)點(diǎn)構(gòu)成一個(gè)完全圖,然后將這個(gè)完全圖進(jìn)行3 層圖卷積,最后輸出一個(gè)5 維的向量,經(jīng)過Softmax 后得到查詢集圖像屬于每一類的概率。特征提取網(wǎng)絡(luò)結(jié)構(gòu)如圖4所示。
圖4 關(guān)系網(wǎng)絡(luò)的結(jié)構(gòu)Fig.4 Architecture of relation network
3.3.1 對比實(shí)驗(yàn)
將本文模型與多個(gè)小樣本分類模型進(jìn)行了比較,實(shí)驗(yàn)結(jié)果見表2和表3。
表2 在Omniglot數(shù)據(jù)集上的效果Table 2 Results on Omniglot dataset %
表3 在miniImageNet數(shù)據(jù)集上的效果Table 3 Results on miniImageNet dataset %
從實(shí)驗(yàn)數(shù)據(jù)可以看到,在兩個(gè)基準(zhǔn)數(shù)據(jù)集上面,本文模型都取得超過基準(zhǔn)模型的效果,雖然在5-way5-shot任務(wù)中的準(zhǔn)確率不及原型網(wǎng)絡(luò),但是已經(jīng)達(dá)到了相近的水平。在特征提取網(wǎng)絡(luò)中添加了注意力機(jī)制,使得特征提取網(wǎng)絡(luò)能夠關(guān)注那些對分類更有幫助的特征。同時(shí),基于圖卷積的關(guān)系網(wǎng)絡(luò)改進(jìn)了傳統(tǒng)關(guān)系網(wǎng)絡(luò)不能利用支撐集圖像特征之間關(guān)系的問題,使得關(guān)系網(wǎng)絡(luò)不僅能夠比較支撐集圖像特征和查詢集圖像特征,而且能夠利用圖卷積的信息傳遞得到支撐集中其他圖像的相關(guān)信息。因此,本文模型能夠取得比基準(zhǔn)模型更高的準(zhǔn)確率。
3.3.2 消融實(shí)驗(yàn)
為了能夠分析清楚網(wǎng)絡(luò)中各個(gè)部分的影響,使用控制變量法對模型進(jìn)行了消融實(shí)驗(yàn),分別驗(yàn)證了注意力機(jī)制和圖卷積的有效性。消融實(shí)驗(yàn)的結(jié)果見表4。
表4 在miniImageNet數(shù)據(jù)集上的消融實(shí)驗(yàn)Table 4 Ablation experiment on miniImageNet dataset %
如表4所示,進(jìn)行了3組實(shí)驗(yàn):(A)只在特征提取網(wǎng)中使用注意力機(jī)制;(B)只在關(guān)系網(wǎng)絡(luò)中使用圖卷積;(A+B)同時(shí)使用注意力機(jī)制和圖卷積。從實(shí)驗(yàn)結(jié)果可以看出針對關(guān)系網(wǎng)絡(luò)提出的兩個(gè)改進(jìn)點(diǎn)都能提高分類的準(zhǔn)確率。在與原型網(wǎng)絡(luò)對比中可以看到,雖然單獨(dú)的使用注意模塊和圖卷積并不能獲得更好的效果,但是,注意力模塊和圖卷積的組合卻能夠明顯提高分類的準(zhǔn)確率。使用注意模塊CBAM能夠加強(qiáng)網(wǎng)絡(luò)對關(guān)鍵特征的提取。通過調(diào)整不同特征的權(quán)重,CBAM能夠使得特征提取網(wǎng)絡(luò)提取到對分類更有效的特征。但是由于特征提取網(wǎng)絡(luò)并不復(fù)雜,只能在空間和通道維度上調(diào)整不同特征的權(quán)重,所以使用CBAM 對最終結(jié)果的提升效果不及使用圖卷積的效果。使用圖卷積不僅能夠讓關(guān)系網(wǎng)絡(luò)使用當(dāng)前圖像和查詢集圖像特征,而且能夠讓關(guān)系網(wǎng)絡(luò)綜合支撐集中各個(gè)圖像特征之間的差異,利用這些差異,關(guān)系網(wǎng)絡(luò)能夠做出更準(zhǔn)確的判斷。同時(shí),發(fā)現(xiàn)兩種方式一起能夠明顯提高傳統(tǒng)關(guān)系網(wǎng)絡(luò)分類的準(zhǔn)確率。
本文提出了基于注意力機(jī)制和圖卷積的解決小樣本分類問題的網(wǎng)絡(luò)。本文模型主要由特征提取網(wǎng)絡(luò)和關(guān)系網(wǎng)絡(luò)組成。在特征提取網(wǎng)絡(luò)中,使用了通道和空間注意力來指導(dǎo)神經(jīng)網(wǎng)絡(luò)提取更重要的特征。在關(guān)系網(wǎng)絡(luò)中,使用了圖卷積讓網(wǎng)絡(luò)在比較查詢集圖像特征和對應(yīng)的支撐集圖像特征的同時(shí)利用圖卷積的消息傳遞獲取支撐集中其他圖像的特征信息。對比實(shí)驗(yàn)表明,在Omnislot 數(shù)據(jù)集和miniImageNet 數(shù)據(jù)集上均取得了比基準(zhǔn)模型更好的效果。消融實(shí)驗(yàn)結(jié)果說明,注意力模塊和圖卷積的使用確實(shí)提高了模型分類的準(zhǔn)確率。