董博文,汪榮貴,楊 娟,薛麗霞
合肥工業(yè)大學(xué) 計(jì)算機(jī)與信息學(xué)院,合肥 230601
近年來深度學(xué)習(xí)取得飛躍性進(jìn)展,在計(jì)算機(jī)視覺方向,如語義分割[1]、目標(biāo)檢測[2-3]、圖像分類[4]等領(lǐng)域,以及自然語言處理方向[5-6]的研究中的表現(xiàn)越來越好,其中以ResNet[7]為代表的一些深度網(wǎng)絡(luò)在圖像分類任務(wù)中的準(zhǔn)確率甚至超過了人類。但是,這些網(wǎng)絡(luò)達(dá)到這樣的準(zhǔn)確率所需要的訓(xùn)練樣本數(shù)量是非常龐大的,而在一些情況下,人們無法獲得大量樣本,或者獲得所需樣本的代價(jià)過高,這時(shí)就要求深度網(wǎng)絡(luò)能夠通過少量樣本的學(xué)習(xí)實(shí)現(xiàn)較好的分類能力。雖然對于人們來說這一點(diǎn)很容易實(shí)現(xiàn),但對深度網(wǎng)絡(luò)來說很難。主要原因是,首先深度網(wǎng)絡(luò)參數(shù)量過大,訓(xùn)練這些參數(shù)理所應(yīng)當(dāng)?shù)匦枰銐驍?shù)量的樣本;其次,深度網(wǎng)絡(luò)模型結(jié)構(gòu)復(fù)雜,對已有樣本有著很強(qiáng)的表達(dá)能力,但是對未知樣本的表達(dá)能力不足,因此需要有大量可用于學(xué)習(xí)的樣本,盡量使這些樣本的分布覆蓋整體樣本的分布,以避免深度網(wǎng)絡(luò)在訓(xùn)練集上效果好、測試集上效果差的現(xiàn)象,即過擬合現(xiàn)象。對此,一些學(xué)者開始研究如何讓機(jī)器實(shí)現(xiàn)人類的這種少樣本學(xué)習(xí)能力,即“小樣本學(xué)習(xí)問題”[8-9]。
小樣本學(xué)習(xí)的研究中主要使用遷移學(xué)習(xí)的思想,即先使用相似任務(wù)的大量樣本預(yù)訓(xùn)練深度網(wǎng)絡(luò),以此模擬人類積累經(jīng)驗(yàn)的過程,然后利用得到的少量的當(dāng)前任務(wù)樣本進(jìn)行網(wǎng)絡(luò)模型參數(shù)的微調(diào)。遷移學(xué)習(xí)的方法雖然提高了網(wǎng)絡(luò)模型的泛化能力,但是預(yù)訓(xùn)練樣本與當(dāng)前任務(wù)的樣本分布可能差別較大,因此在研究中需要盡可能提高模型快速學(xué)習(xí)能力。據(jù)此遷移學(xué)習(xí)的思想又出現(xiàn)了幾類主要方法:數(shù)據(jù)增強(qiáng)、注意力機(jī)制、元學(xué)習(xí)、度量學(xué)習(xí)。對于當(dāng)前任務(wù)樣本過少的問題,數(shù)據(jù)增強(qiáng)[10-11]是最直觀的解決方法,但是虛擬樣本的生成并不能覆蓋真正的樣本空間,因此數(shù)據(jù)增強(qiáng)的方法只能在一定程度上提升模型效果;注意力機(jī)制[12-13]則是讓網(wǎng)絡(luò)模型將學(xué)習(xí)的重心放在樣本更重要的區(qū)域,提高了網(wǎng)絡(luò)模型學(xué)習(xí)樣本的效率;元學(xué)習(xí)[14-15]是一種讓機(jī)器模仿人類根據(jù)已有經(jīng)驗(yàn)進(jìn)行快速學(xué)習(xí)的策略,這種學(xué)習(xí)策略很適合在小樣本的情況下有效學(xué)習(xí);而度量學(xué)習(xí)[16]的方法簡單有效,其旨在找到一個(gè)適用于具體任務(wù)的距離度量方法,使相似樣本的距離更近。
度量學(xué)習(xí)已經(jīng)在小樣本研究中取得了很好的效果,但是傳統(tǒng)度量模型[17-18]利用神經(jīng)網(wǎng)絡(luò)進(jìn)行分類時(shí),存在兩個(gè)問題:首先,用于分類的深度網(wǎng)絡(luò)僅使用頂層特征進(jìn)行度量學(xué)習(xí),從特征提取的角度來看,頂層樣本特征分辨率低,學(xué)習(xí)的更多是語義特征,而忽略了樣本很多細(xì)節(jié)信息[19]。其次,在獲得樣本特征后,傳統(tǒng)度量學(xué)習(xí)方法在求解每個(gè)類表達(dá)的過程中并未考慮到訓(xùn)練集中樣本類間與類內(nèi)的信息關(guān)聯(lián)。針對以上兩個(gè)問題,基于度量學(xué)習(xí)思想,本文提出以下創(chuàng)新點(diǎn):
(1)利用多尺度特征[20]可以有效緩解單一尺度特征信息片面化的問題,基于多尺度特征的目標(biāo)識別是計(jì)算機(jī)視覺領(lǐng)域的一個(gè)基本挑戰(zhàn),其應(yīng)用可以有效提升智能體目標(biāo)識別能力。因此本文設(shè)計(jì)了包含卷積與全局平均池化以及跳躍連接的最小殘差神經(jīng)網(wǎng)絡(luò)塊,并基于最小殘差神經(jīng)網(wǎng)絡(luò)塊設(shè)計(jì)跨尺度連接的多尺度特征提取器結(jié)構(gòu),使得提取到的特征有豐富語義信息,且減少隨卷積網(wǎng)絡(luò)深度增加而丟失的特征細(xì)節(jié)信息。
(2)圖神經(jīng)網(wǎng)絡(luò)可以充分挖掘數(shù)據(jù)之間的豐富關(guān)系,并且圖結(jié)構(gòu)可以很容易對樣本數(shù)據(jù)進(jìn)行聚類,進(jìn)而使數(shù)據(jù)簇更易于分類。本文提出了一種掩碼圖模型,通過元學(xué)習(xí)策略生成掩碼,在每次結(jié)點(diǎn)更新的過程中從相鄰結(jié)點(diǎn)中屏蔽掉不利于更新的結(jié)點(diǎn);此外,本文圖模型信息傳播過程使用一種更有效的點(diǎn)乘注意力機(jī)制[21]而非使用帶有注意力的L1距離度量。
(3)在使用融合的多尺度特征計(jì)算原型過程中,提出了特征貢獻(xiàn)度,反映特征在嵌入空間分布與該類原型之間的位置關(guān)系,并提出了一種互斥損失,這兩個(gè)創(chuàng)新促使模型生成更靠近真實(shí)分布中心的原型。
小樣本學(xué)習(xí)是由Li等人[9]在2006年首次提出,研究如何使模型完成一個(gè)新任務(wù)而僅使用極少量訓(xùn)練樣本。從仿生學(xué)的角度,小樣本學(xué)習(xí)發(fā)展主要模仿人類快速學(xué)習(xí)的過程,即遷移學(xué)習(xí)的過程。Zamir等人[22]和Yu等人[23]通過多個(gè)任務(wù)之間相互遷移學(xué)習(xí),得到多任務(wù)之間的遷移效率矩陣,以此來進(jìn)行任務(wù)相似性判斷以及尋求更高效的遷移。但是,即使是相似任務(wù)的高效遷移,其效果也并不能滿足人們對分類的要求。因此,遷移學(xué)習(xí)又衍生出幾類提升遷移效果的方法,主要有數(shù)據(jù)增強(qiáng)、元學(xué)習(xí)、度量學(xué)習(xí)。
數(shù)據(jù)增強(qiáng)的方法旨在增加小樣本任務(wù)的樣本數(shù)量,以平衡任務(wù)樣本數(shù)量與預(yù)訓(xùn)練樣本數(shù)量。數(shù)據(jù)增強(qiáng)的實(shí)現(xiàn)中,Tremblay 等人[10]基于域隨機(jī)化理論,使用專業(yè)的三維軟件,通過在軟件中調(diào)整虛擬目標(biāo)的角度、光照、紋理等,以及改變目標(biāo)在背景中的位置來生成需要的樣本。該方法雖然可以生成大量樣本,但需提前獲取目標(biāo)三維模型,實(shí)現(xiàn)過程復(fù)雜,故只能應(yīng)用于特定任務(wù)。而Hariharan 等人[11]提出的模型通過模仿訓(xùn)練樣本之間的映射關(guān)系生成新的樣本,不需獲取樣本目標(biāo)的額外信息,適用性更廣。有一些數(shù)據(jù)增強(qiáng)的方法則結(jié)合語義信息,從另一個(gè)角度擴(kuò)充了樣本。Chen等人[24]使用編碼器將樣本映射到語義空間,在語義空間中分別根據(jù)兩種語義分布(語義高斯和語義近鄰)找到相近的語義,并將其通過解碼器轉(zhuǎn)回到圖片空間從而進(jìn)行數(shù)據(jù)增強(qiáng)。同樣是利用語義信息進(jìn)行數(shù)據(jù)增強(qiáng),Alfassy等人[25]搭建用于對圖片進(jìn)行交、并、差操作的神經(jīng)網(wǎng)絡(luò),圖片的交、并、差依據(jù)圖片語義所含元素,在進(jìn)行端到端的訓(xùn)練后可以利用該網(wǎng)絡(luò)生成新樣本。數(shù)據(jù)增強(qiáng)需要依賴諸如目標(biāo)三維模型、訓(xùn)練樣本之間的關(guān)系以及語義等信息,但是利用這些信息進(jìn)行的數(shù)據(jù)增強(qiáng),只是對真實(shí)樣本的模仿,這種模仿不可能實(shí)現(xiàn)無偏差,因此只能盡力而為。
元學(xué)習(xí)的方法是在多個(gè)任務(wù)之上進(jìn)行模型的訓(xùn)練,學(xué)習(xí)任務(wù)之間的共性,以增強(qiáng)模型泛化能力,使模型在訓(xùn)練數(shù)據(jù)不充足的情況下提高性能。這種思想一般是通過設(shè)置元學(xué)習(xí)器、基礎(chǔ)學(xué)習(xí)器實(shí)現(xiàn),元學(xué)習(xí)器用來積累模型執(zhí)行的多任務(wù)之間的共性,而基礎(chǔ)學(xué)習(xí)器則聚焦于模型處理單一任務(wù)的性能。Munkhdalai 等人[14]和Wang 等人[26]使用CNN(convolutional neural network)作為元學(xué)習(xí)器和基礎(chǔ)學(xué)習(xí)器并構(gòu)建出元學(xué)習(xí)模型,而Ravi 等人[27]則使用了長短時(shí)記憶網(wǎng)絡(luò)(long short-term memory network,LSTM)作為元學(xué)習(xí)器,其中LSTM 的細(xì)胞狀態(tài)為元學(xué)習(xí)器的參數(shù)。更簡單一點(diǎn)的,Sun 等人[28]和Keshari 等人[29]設(shè)置用于放縮和偏移基礎(chǔ)學(xué)習(xí)器卷積核參數(shù)的參數(shù)作為元學(xué)習(xí)器參數(shù)。還有些元學(xué)習(xí)的實(shí)現(xiàn)偏重于模型快速適應(yīng)能力,典型的模型為MAML(model-agnostic meta-learning)[15],它在訓(xùn)練時(shí)將每個(gè)任務(wù)對模型初始參數(shù)的優(yōu)化結(jié)果通過梯度求和的形式綜合在一起,進(jìn)行梯度的反向傳播,使模型有很好的泛化能力,從而在新任務(wù)到來時(shí)僅進(jìn)行少量樣本的學(xué)習(xí)就可以達(dá)到較好的效果。Boney 等人[30]將MAML 算法應(yīng)用到半監(jiān)督任務(wù)中,也取得了不錯(cuò)的效果。MAML 雖然有很強(qiáng)的泛化能力,但在一些任務(wù)中,元學(xué)習(xí)階段會出現(xiàn)過擬合現(xiàn)象。Jamal等人[31]提出的模型擴(kuò)展了MAML,提出了兩種新范式避免模型元學(xué)習(xí)階段訓(xùn)練過擬合,同時(shí)提升模型的泛化能力。元學(xué)習(xí)能夠有效進(jìn)行“經(jīng)驗(yàn)和知識”的積累,并指導(dǎo)模型對任務(wù)進(jìn)行快速學(xué)習(xí),但元學(xué)習(xí)器的設(shè)置會增加模型復(fù)雜度,因此本文設(shè)計(jì)了一個(gè)元學(xué)習(xí)器僅用于生成掩碼,而使用更加簡單有效的度量學(xué)習(xí)方法作為基礎(chǔ)學(xué)習(xí)器。
度量學(xué)習(xí)的方法是模型將樣本映射到特征空間,并進(jìn)行相似性度量,以找到和測試樣本最相似的標(biāo)注樣本,從而實(shí)現(xiàn)分類。其中度量方法的選取有兩種情況:(1)使用傳統(tǒng)固定的距離度量方法,如歐式距離度量、余弦距離度量等。Koch等人[16]提出的孿生網(wǎng)絡(luò)和Snell等人[18]提出的原型網(wǎng)絡(luò)(prototypical networks,PN)分別使用了L1距離與歐式距離,Vinyals 等人[17]提出的匹配網(wǎng)絡(luò)(matching networks,MN)則使用了余弦距離作為度量方法,這些模型中由于度量方法是固定的,他們將研究的重心放在如何更好地獲得用于度量的特征向量。孿生網(wǎng)絡(luò)設(shè)計(jì)了一種對稱的網(wǎng)絡(luò)結(jié)構(gòu),將要比較的兩個(gè)樣本分別輸入到這個(gè)網(wǎng)絡(luò)對稱的兩部分中,在網(wǎng)絡(luò)的輸出端將兩部分提取到的特征進(jìn)行L1距離度量,得到兩個(gè)樣本屬于同一類的概率;MN則將提取到的支持集特征輸入到一個(gè)雙向LSTM 中,整個(gè)支持集作為上下文,以消除每個(gè)任務(wù)隨機(jī)選擇支持集而產(chǎn)生的差異性。而PN是通過找出每個(gè)類在特征空間中的原型即類在嵌入空間中的特征表達(dá),用于度量。(2)使用參數(shù)可學(xué)習(xí)的度量方法。這類方法過去如Xing 等人[32]的研究一樣通過在度量函數(shù)中設(shè)置可學(xué)習(xí)參數(shù)而實(shí)現(xiàn),而現(xiàn)在更多的是搭建專門用于距離度量的神經(jīng)網(wǎng)絡(luò),如Sung 等人提出的關(guān)系網(wǎng)絡(luò)[33]等。度量學(xué)習(xí)模型在設(shè)計(jì)時(shí)原理清晰,結(jié)構(gòu)相對簡單,同時(shí),通過尋找相似樣本而實(shí)現(xiàn)歸類的思想使少樣本學(xué)習(xí)更有效。而最近,很多度量學(xué)習(xí)的方法,通過圖神經(jīng)網(wǎng)絡(luò)(graph neural network,GNN)[34-35]來組織和挖掘樣本關(guān)系并用于距離度量,這些研究取得了不錯(cuò)的效果。但是這些傳統(tǒng)圖模型在結(jié)點(diǎn)更新時(shí)使用無差別的更新策略,會造成無用信息傳播,干擾分類。針對這個(gè)問題,本文設(shè)計(jì)一個(gè)包含掩碼的新的GNN網(wǎng)絡(luò),通過掩碼篩選邊來指導(dǎo)圖中結(jié)點(diǎn)更新,實(shí)現(xiàn)特征更好的信息交互,從而有更好的分類效果。
這部分首先介紹小樣本學(xué)習(xí)的問題定義,然后介紹本文方法的整體架構(gòu)以及詳細(xì)的實(shí)現(xiàn)過程。
小樣本學(xué)習(xí)問題在計(jì)算機(jī)視覺領(lǐng)域的任務(wù)T中,一般將數(shù)據(jù)集劃分為訓(xùn)練集Tra、支持集Sup以及查詢集Que。其中訓(xùn)練集來自單獨(dú)的樣本空間,與支持集和查詢集樣本類別互斥,用于訓(xùn)練階段預(yù)訓(xùn)練網(wǎng)絡(luò)模型。支持集與查詢集的樣本類別完全相同,但是樣本互斥,其中支持集只有少量樣本,用來在測試階段訓(xùn)練網(wǎng)絡(luò)模型,查詢集用來測試模型使用支持集訓(xùn)練后對其中新類別的識別準(zhǔn)確率。由于支持集只含有少量樣本且樣本類別未出現(xiàn)在訓(xùn)練集中,這樣就可以檢測模型在少樣本情況下的學(xué)習(xí)能力。一般,如果支持集中有N類的樣本,且每類有K張樣本圖片,則稱這個(gè)小樣本任務(wù)為“N-wayK-shot”任務(wù)。
Vinyals 等人[17]提出了周期性的策略以在訓(xùn)練階段模擬小樣本任務(wù)的設(shè)定,這種訓(xùn)練策略由于融入元學(xué)習(xí)思想,在小樣本分類任務(wù)中十分有效,也因此被廣泛使用。具體的,如果小樣本任務(wù)為N-wayK-shot任務(wù),則在訓(xùn)練階段的每個(gè)周期,從訓(xùn)練集中隨機(jī)選擇N個(gè)類別的樣本,并從這N個(gè)類別樣本每一類中隨機(jī)挑選出K個(gè)訓(xùn)練樣本模擬支持集;再從這N個(gè)類別剩下的樣本中隨機(jī)挑選C個(gè)樣本作為查詢集,則有。訓(xùn)練階段使用這樣模擬測試階段的數(shù)據(jù)設(shè)定進(jìn)行周期性迭代訓(xùn)練,直到收斂。
本文方法主要分為以下幾部分:(1)用于提取多尺度特征的,基于最小殘差神經(jīng)網(wǎng)絡(luò)塊與卷積塊的多尺度特征提取器;(2)用于增強(qiáng)多尺度特征的掩碼圖網(wǎng)絡(luò);(3)樣本分類及損失函數(shù)部分。如圖1所示,為本文方法在小樣本學(xué)習(xí)“5-way 1-shot”分類問題上的整體流程。
圖1 本文方法的整體流程Fig.1 Overall framework of proposed method
匹配網(wǎng)絡(luò)[17]和原型網(wǎng)絡(luò)[18]等小樣本學(xué)習(xí)的經(jīng)典網(wǎng)絡(luò)模型采用由4個(gè)卷積塊組成的四層卷積網(wǎng)絡(luò)(ConvNet)提取特征,但是單一尺度的特征對樣本信息利用不充分[20],本文基于四層卷積神經(jīng)網(wǎng)絡(luò),設(shè)計(jì)多尺度特征提取器。如圖2所示,多尺度特征提取器共有3個(gè)分支,每個(gè)分支前半部分為卷積塊組成的原始特征編碼器,卷積塊的結(jié)構(gòu)如圖3(a)所示;后半部分為最小殘差神經(jīng)網(wǎng)絡(luò)塊組成的殘差塊編碼器。
圖2 多尺度特征提取器結(jié)構(gòu)Fig.2 Architecture of multi-scale feature extractor
圖3 最小殘差神經(jīng)網(wǎng)絡(luò)塊與卷積塊Fig.3 Smallest residual block and convolutional block
最小殘差神經(jīng)網(wǎng)絡(luò)塊由1×1 卷積與全局平均池化層(global average pool,GAP)組成,如圖3(b)所示。在最小殘差神經(jīng)網(wǎng)絡(luò)塊中加入轉(zhuǎn)換通道的跳躍連接,以確保特征細(xì)節(jié)信息的充分提取。對于輸入x,經(jīng)過最小殘差神經(jīng)網(wǎng)絡(luò)塊得到的輸出如式(1):
其中,F(xiàn)n是輸出通道為N的兩層1×1 卷積,W是將x轉(zhuǎn)換成通道數(shù)為N的卷積操作,GAP 采樣尺寸為2×2。多尺度特征提取器相鄰分支之間通過跨尺度連接對特征按元素求和,將他們聯(lián)系在一起,避免同一樣本的多尺度特征割裂,同時(shí)將細(xì)節(jié)信息從淺層特征流向深層,增強(qiáng)深層特征的表達(dá)能力。
任務(wù)T={xi|xi∈Sup?Que}提取得到第l級尺度特征f l(xi),如式(2):
在每個(gè)分支網(wǎng)絡(luò)最后使用全局平均池化代替全連接將特征圖轉(zhuǎn)化為特征向量,通過設(shè)置可學(xué)習(xí)參數(shù)作為注意力進(jìn)行L個(gè)尺度特征的融合,得到多尺度融合特征,如式(3):
多尺度特征提取器的3個(gè)分支分別使用2、3、4個(gè)卷積塊作為原始特征編碼器,提取到大小為20×20、10×10以及5×5 像素的原始特征圖。深度學(xué)習(xí)網(wǎng)絡(luò)中,不同深度的特征編碼器提取到的特征攜帶不同比例的細(xì)節(jié)信息和語義信息[19-20],淺層編碼器提取到的特征分辨率高,圖片細(xì)節(jié)保留較多,細(xì)節(jié)信息豐富;而深層編碼器的特征分辨率低,特征更為抽象,含有更多語義信息。對多尺度原始特征使用最小殘差網(wǎng)絡(luò)塊組成的殘差塊編碼器進(jìn)行信息提取。通過殘差塊中的1×1 卷積對原始特征進(jìn)行跨通道信息交互[4,36],提取信息的同時(shí)較大程度保留了原始特征的細(xì)節(jié)特性和語義特性。最后通過GAP將提取到的多尺度特征采樣為相同大小,進(jìn)行多尺度特征融合,得到語義與細(xì)節(jié)信息兼具的多尺度融合特征。
在網(wǎng)絡(luò)增加分支帶來的實(shí)現(xiàn)難度方面,本文提出的多尺度特征提取器,基于ConvNet的卷積塊以及設(shè)計(jì)的最小殘差神經(jīng)網(wǎng)絡(luò)塊,搭建網(wǎng)絡(luò)時(shí),將四層卷積網(wǎng)絡(luò)的卷積塊的數(shù)目從4提升到9,同時(shí)加入了3個(gè)最小殘差神經(jīng)網(wǎng)絡(luò)塊,使網(wǎng)絡(luò)卷積核參數(shù)增加了529 856個(gè),提升了網(wǎng)絡(luò)過擬合的風(fēng)險(xiǎn)。但是另一方面,受文獻(xiàn)[36]的啟發(fā),將ConvNet中與最后一個(gè)卷積鄰接的全連接層替換為全局平均池化層,在不影響網(wǎng)絡(luò)分類性能的基礎(chǔ)上,使其減少了全連接層的819 328 個(gè)參數(shù),最終使多尺度特征提取器所有分支參數(shù)量之和維持在小于ConvNet的水平,從而避免了多尺度特征提取器難以訓(xùn)練的情況。多尺度特征提取器與ConvNet 網(wǎng)絡(luò)塊參數(shù)量對比如表1 所示。但是多尺度特征提取器在實(shí)際實(shí)現(xiàn)時(shí)會加大顯存占用量,并且其中添加的跨尺度連接,增加了網(wǎng)絡(luò)的運(yùn)算量,一定程度降低了效率,因此在實(shí)驗(yàn)部分驗(yàn)證了網(wǎng)絡(luò)的實(shí)時(shí)性。
表1 本文特征提取器與四層卷積網(wǎng)絡(luò)參數(shù)量對比Table 1 Parameters comparison of proposed feature extractor with ConvNet
大部分小樣本模型在分類過程中僅考慮特征的標(biāo)簽信息,并未考慮到特征之間的信息關(guān)聯(lián),而圖結(jié)構(gòu)可以充分挖掘數(shù)據(jù)之間的豐富關(guān)系,通過圖結(jié)點(diǎn)間信息交互增強(qiáng)多尺度融合特征。但是傳統(tǒng)圖更新時(shí)采用無差別更新策略,在更新一個(gè)結(jié)點(diǎn)時(shí)無選擇地使用相鄰結(jié)點(diǎn)。圖神經(jīng)網(wǎng)絡(luò)最早由Gori等人[37]提出,他們構(gòu)建的圖神經(jīng)網(wǎng)絡(luò)中結(jié)點(diǎn)的狀態(tài)取決于3個(gè)因素:結(jié)點(diǎn)自身的標(biāo)簽、相鄰結(jié)點(diǎn)狀態(tài)和相鄰結(jié)點(diǎn)標(biāo)簽。而無差別的更新策略忽略了相鄰結(jié)點(diǎn)標(biāo)簽這一因素,導(dǎo)致非同類信息在同類結(jié)點(diǎn)之間傳播。本文提出選擇性更新策略,通過篩選邊,區(qū)分結(jié)點(diǎn)相似度,實(shí)現(xiàn)在更新時(shí)考慮相鄰結(jié)點(diǎn)標(biāo)簽這一因素。本節(jié)將介紹本文提出的掩碼圖網(wǎng)絡(luò),其結(jié)構(gòu)如圖4。
圖4 掩碼圖網(wǎng)絡(luò)框架在“2-way 2-shot”分類問題上的流程Fig.4 Framework of mask GNN on“2-way 2-shot”classification
首先是圖網(wǎng)絡(luò)的構(gòu)建。將特征提取器輸出的多尺度融合特征構(gòu)建為圖的原始結(jié)點(diǎn)V0i =F(xi),并通過比較標(biāo)簽獲得初始化邊的值e0ij,如式(4)所示。
使用生成的掩碼與邊矩陣按元素相乘,置零冗余和負(fù)增益的邊,從而篩出對圖的更新有增益的邊。根據(jù)邊特征計(jì)算出結(jié)點(diǎn)間需要傳播的信息Inf n Nei→i,如式(6):
將增益信息融入結(jié)點(diǎn)完成一次結(jié)點(diǎn)更新,得到新的結(jié)點(diǎn),如式(7)所示:
式中,λ為超參數(shù)。每一次結(jié)點(diǎn)更新過后,使用點(diǎn)乘注意力[21]重新計(jì)算邊特征用于新一次結(jié)點(diǎn)更新,如式(8)所示:
式中,g1與g2是用于特征轉(zhuǎn)換以更好地度量結(jié)點(diǎn)相似性的線性變換。
傳統(tǒng)圖模型將結(jié)點(diǎn)間的L1距離輸入多層感知機(jī)得到邊特征[35]。這種度量方式通過引入額外的卷積神經(jīng)網(wǎng)絡(luò)或全連接層,讓網(wǎng)絡(luò)自己學(xué)習(xí)輸入結(jié)點(diǎn)特征各個(gè)維度上的權(quán)重,達(dá)到添加注意力的目的。而本文使用的點(diǎn)乘注意力機(jī)制[21],屬于一種乘法注意力,通過結(jié)點(diǎn)特征向量點(diǎn)乘即可求得注意力。由于未引入額外神經(jīng)網(wǎng)絡(luò)進(jìn)行權(quán)重學(xué)習(xí),同時(shí),點(diǎn)乘的計(jì)算在網(wǎng)絡(luò)模型實(shí)施中可批處理為矩陣運(yùn)算,進(jìn)而可以通過高度優(yōu)化的矩陣乘法庫并行地計(jì)算,加快了圖模型推理的速度。此外,本文使用的點(diǎn)乘注意力在傳統(tǒng)乘法注意力的基礎(chǔ)上增加了縮放因子,避免輸出邊特征過大造成的歸一化之后的梯度過小問題。實(shí)驗(yàn)表明,本文的掩碼圖在達(dá)到較高的分類準(zhǔn)確率的同時(shí)有較好的時(shí)間性能。
本文使用增強(qiáng)的多尺度特征計(jì)算類表達(dá)特征,即類原型,并通過距離度量的方式進(jìn)行分類。在類表達(dá)特征的計(jì)算中,原型網(wǎng)絡(luò)[18]基于伯格曼散度思想提出均值類原型,Banerjee 等人[38]證明一組點(diǎn)在特定的空間中如果滿足任意概率分布,這些點(diǎn)的均值點(diǎn)是這個(gè)特定空間中距離這些點(diǎn)平均距離的最小值點(diǎn)。本文認(rèn)為在小樣本情況下,當(dāng)支持集僅有極少數(shù)樣本時(shí),不滿足任意概率分布,不能簡單地通過求均值得到類表達(dá),而應(yīng)評估嵌入特征與真實(shí)類原型之間的距離再計(jì)算。如樣本中存在目標(biāo)遮擋、目標(biāo)僅有部分在圖片中、目標(biāo)過小或過大等,將導(dǎo)致樣本特征遠(yuǎn)離類原型。因此提出一種預(yù)估機(jī)制,通過特征貢獻(xiàn)度改進(jìn)均值型原型,生成更接近真實(shí)分布中心的類原型。
對于類別為m的支持集Supm中的樣本Vi,比較該樣本與Supm中其他樣本的分布情況來獲取貢獻(xiàn)度。具體的,先通過求均值的方法計(jì)算出類Supm在特征空間的偽原型Pm′,如式(9):
使用SoftMax 函數(shù),歸一化類中樣本Vi到所屬類偽原型的距離與其他偽原型距離,得到貢獻(xiàn)度Ci,如式(10):
根據(jù)計(jì)算得到的樣本特征貢獻(xiàn)度,計(jì)算優(yōu)化的類原型Pm,如式(11):
實(shí)驗(yàn)表明,使用特征貢獻(xiàn)度求出的原型可以更好地表達(dá)類特征,效果示意如圖5 所示。圖中共有大象、駱駝和麋鹿3個(gè)類別,分別用綠色、黃色和紅色區(qū)分,每個(gè)類別5個(gè)支持集樣本,1個(gè)測試集樣本。由圖可以看出,當(dāng)部分樣本由于目標(biāo)不明顯或不完整導(dǎo)致其嵌入向量偏離類原型較遠(yuǎn)時(shí),本文利用特征貢獻(xiàn)度求出的類原型相比于均值原型更有代表性,可以避免一些測試樣本分類錯(cuò)誤,如圖中大象、駱駝?lì)悇e。圖中虛線表示使用均值原型分類時(shí)通過度量最近距離得到的分類結(jié)果,實(shí)線為使用本文改進(jìn)原型時(shí)的分類結(jié)果。
圖5 特征貢獻(xiàn)度計(jì)算原型效果示意圖Fig.5 Effect diagram of prototypes computed with feature contribution degree
同時(shí),本文提出了一個(gè)新的互斥損失,在模型學(xué)習(xí)的過程中,促使原型互斥地生成,從而提高度量學(xué)習(xí)能力。損失計(jì)算過程中,使用式(12)的度量機(jī)制度量樣本到原型的距離:
式中,f1與f2是特征轉(zhuǎn)換神經(jīng)網(wǎng)絡(luò)。所屬類別為Supm的樣本Vi的互斥損失,通過度量其與非本類原型的平均距離獲得,如式(13):
式中,N為類別總數(shù),m′表示m之外的類別,τ為溫度參數(shù),用于控制不同樣本損失差異及大小。
當(dāng)前批次所有支持集樣本損失和如式(14):
式中,B為批大小。
支持集樣本與其他類原型之間的距離越小,損失越大;距離越大,損失越小。因此這個(gè)損失使原型盡可能遠(yuǎn)離其他類樣本簇,促使原型互斥。
對測試樣本進(jìn)行分類,計(jì)算查詢集樣本與各類原型的距離,最小的進(jìn)行標(biāo)簽傳播。使用交叉熵?fù)p失作為分類損失,每一次反向傳播損失的分類損失為一次迭代的所有批次樣本的交叉熵?fù)p失和,如式(15)所示:
式中,Yb和Y^b分別表示第b批次查詢集樣本的真實(shí)標(biāo)簽和預(yù)測標(biāo)簽。最終模型每個(gè)訓(xùn)練周期反向傳播總損失為互斥損失與分類損失之和,見式(16):
本文在MiniImagenet、Cifar-100和Caltech-256數(shù)據(jù)集進(jìn)行了5-way 1-shot 與5-way 5-shot 分類任務(wù)的實(shí)驗(yàn),下面分別介紹這3個(gè)數(shù)據(jù)集。
MiniImagenet 數(shù) 據(jù) 集 由Vinyals 等[17]提 出,是 從Imagenet中抽出的子集,專用于小樣本學(xué)習(xí)研究。數(shù)據(jù)集共包含100個(gè)類別,每個(gè)類別包含600張84×84 的彩色圖片。將數(shù)據(jù)集按Ravi等人[27]的設(shè)定劃分:64個(gè)類別用作訓(xùn)練集,16個(gè)用于驗(yàn)證集,20個(gè)用于測試集。
Cifar-100 包含100 個(gè)類別的樣本,每個(gè)類別600 張32×32 的彩色圖片,另外,這100 個(gè)類別來自于20 個(gè)超類。在研究時(shí),劃分60個(gè)類別作為訓(xùn)練集,16個(gè)作為驗(yàn)證集,20個(gè)作為測試集。由于其中樣本分辨率被調(diào)整為32×32,分類任務(wù)難度增大。
Caltech-256 數(shù)據(jù)集包含256 個(gè)類別,共計(jì)30 607 張圖片,這些圖片都下載自谷歌圖片,并手工篩除了不合類別要求的圖片。數(shù)據(jù)集中每個(gè)類別最少80 個(gè)樣本,并引入了一個(gè)新的更大的復(fù)雜類別來測試背景誤判能力。
實(shí)驗(yàn)中,對于一些超參數(shù)以及其他實(shí)驗(yàn)設(shè)置如下:模型訓(xùn)練階段,使用自適應(yīng)矩估計(jì)算法(adaptive moment estimation)優(yōu)化模型參數(shù),并設(shè)置初始學(xué)習(xí)率為1×10-3,權(quán)重衰減為1×10-6。對MiniImagenet和Cifar-100數(shù)據(jù)集,每經(jīng)過15 000 個(gè)訓(xùn)練周期學(xué)習(xí)率衰減為一半,共訓(xùn)練100 000個(gè)周期,Caltech-256數(shù)據(jù)集則由于樣本較少,設(shè)置為12 000個(gè)周期學(xué)習(xí)率衰減一半,共訓(xùn)練84 000個(gè)周期。本文進(jìn)行了5-way 1-shot實(shí)驗(yàn)與5-way 5-shot實(shí)驗(yàn),訓(xùn)練時(shí)批大小分別設(shè)置為40 與20,即分別將40 與20個(gè)任務(wù)同時(shí)計(jì)算損失,用于反向傳播。掩碼圖網(wǎng)絡(luò)進(jìn)行3次更新且更新超參數(shù)λ取值0.5,互斥損失中溫度參數(shù)τ設(shè)置為0.8。在驗(yàn)證與測試階段,隨機(jī)抽取每類15個(gè)樣本作為查詢集。所有實(shí)驗(yàn)均在NvidiaRTX 2080Ti上完成。
3.3.1 與基于度量學(xué)習(xí)方法進(jìn)行對比
將本文方法和基于度量學(xué)習(xí)的經(jīng)典模型MN[17]、PN[18]、TEAM(transductive episodic-wise adaptive metric)[39]在MiniImagenet、Cifar-100、Caltech-256 數(shù)據(jù)集上進(jìn)行5-way 1-shot 和5-way 5-shot 任務(wù)的對比實(shí)驗(yàn),結(jié)果如表2 所示。實(shí)驗(yàn)結(jié)果表明,與經(jīng)典的度量方法比較,本文在5-way 1-shot和5-way 5-shot分類任務(wù)上都有較大的分類準(zhǔn)確率提升,說明本文方法優(yōu)化了度量學(xué)習(xí)的結(jié)果,能有效用于小樣本學(xué)習(xí)的分類任務(wù)。
表2 度量學(xué)習(xí)方法在各數(shù)據(jù)集上的5-way分類結(jié)果Table 2 5-way classification results of metric learning methods on different datasets %
為了更直觀地顯示本文各部分對度量學(xué)習(xí)的優(yōu)化效果,使用t-SNE(t-distributed stochastic neighbor embedding)[40]可視化了初始測試樣本、多尺度特征提取器提取到的多尺度融合特征以及掩碼圖優(yōu)化后的多尺度融合特征,如圖6所示。
圖中圓點(diǎn)表示支持集樣本,叉號表示查詢集樣本,不同顏色表示不同的標(biāo)簽,圖6(a)、(b)、(c)依次為初始樣本、多尺度融合特征和掩碼圖優(yōu)化后的特征。由圖可以看出原始樣本映射到二維空間后,不同類別樣本混雜在一起,無法進(jìn)行有效區(qū)分;多尺度融合特征相比于原始樣本已經(jīng)有一定程度的聚類,說明本文的多尺度特征提取器有利于度量學(xué)習(xí)的進(jìn)行;掩碼圖增強(qiáng)的特征則在二維空間中分簇明顯,并且增強(qiáng)的查詢集樣本距離所屬類別的支持集簇很近,能很好地用于分類,證明了方法的有效性。
圖6 樣本(特征)的t-SNE可視化Fig.6 t-SNE visualization of samples(features)
3.3.2 與基于圖方法進(jìn)行對比
基于圖的方法比較經(jīng)典的有Liu 等人提出的TPN(transductive propagation network)[34]和Kim 等 人 提 出 的EGNN(edge-labeling graph neural network)[35]。TPN通過ConvNet 提取所有樣本特征并用這些特征構(gòu)建出一個(gè)圖結(jié)構(gòu),在標(biāo)簽傳播階段通過轉(zhuǎn)導(dǎo)推理的方式完成標(biāo)簽傳播。EGNN提取特征的過程與TPN相同,而EGNN構(gòu)建出的圖結(jié)構(gòu)中加入了邊特征來表示邊連接的兩個(gè)結(jié)點(diǎn)的相似程度。構(gòu)建好這種圖結(jié)構(gòu)后,在其上進(jìn)行數(shù)次結(jié)點(diǎn)和邊的更新,根據(jù)最終獲得的邊特征來判斷兩個(gè)樣本屬于同一類的概率。
本文和TPN、EGNN在不同數(shù)據(jù)集上進(jìn)行5-way 1-shot和5-way 5-shot 任務(wù)的對比實(shí)驗(yàn),結(jié)果如表3 所示。由表3 可以看出,本文方法在MiniImagenet 和Caltech-256數(shù)據(jù)集上提升明顯;而在Cifar-100 數(shù)據(jù)集上,相比于EGNN 的分類性能提升不明顯。這是由于該訓(xùn)練集中樣本尺寸過小,使用多尺度特征提取器與ConvNet相比優(yōu)勢不明顯。說明本文方法在樣本原始特征充足情況下,可以更充分提取并利用豐富的特征信息,有效提高分類準(zhǔn)確率;而在樣本本身特征不夠充裕的情況下,僅能較小提升特征的提取,略優(yōu)于傳統(tǒng)圖模型。
表3 基于圖的方法在不同數(shù)據(jù)集上的5-way分類準(zhǔn)確率Table 3 5-way classification accuracies of GNN methods on different datasets %
此外,在二維空間可視化了本文掩碼圖經(jīng)過選擇性更新得到的特征,并與EGNN無差別更新策略得到的特征進(jìn)行了對比,如圖7 所示。圖7(a)為EGNN 的特征,圖7(b)為本文增強(qiáng)特征。由圖可以看出,EGNN中支持集樣本的簇相互之間的距離較近,查詢集樣本離對應(yīng)支持集樣本簇比較遠(yuǎn);而本文中支持集樣本的簇相互之間的距離大,查詢集樣本大部分都在同一類別支持集的簇中。這說明本文的選擇性更新策略緩解了類內(nèi)信息的類間傳播問題,使不同類特征易于區(qū)分。
圖7 選擇性更新與無差別更新策略對比Fig.7 Comparision of selective update with undifferentiated update strategies
3.3.3 與其他方法進(jìn)行對比
除了基于度量學(xué)習(xí)和基于圖的方法,常用于小樣本學(xué)習(xí)研究的還有元學(xué)習(xí)和數(shù)據(jù)增強(qiáng)方法。其中元學(xué)習(xí)效果較好的有FEAT(few-shot embedding adaptation with transformer)[41]和DTN(diversity transfer network)[42]。FEAT 提出了一種自適應(yīng)轉(zhuǎn)換特征的方法,使特征變?yōu)槿蝿?wù)相關(guān),增強(qiáng)泛化能力。DTN 通過新的有效的元分類損失進(jìn)行類間樣本多樣性的學(xué)習(xí)。數(shù)據(jù)增強(qiáng)方法中,效果較好的有通過語義進(jìn)行增強(qiáng)的Dual TriNet[24]。
將本文方法與這些經(jīng)典方法在不同數(shù)據(jù)集上進(jìn)行5-way 1-shot 和5-way 5-shot 任務(wù)的對比實(shí)驗(yàn),結(jié)果如表4所示。實(shí)驗(yàn)結(jié)果顯示,與基于元學(xué)習(xí)的模型MAML[15]、DTN、FEAT相比,本文方法的性能都有較大提升。由于數(shù)據(jù)增強(qiáng)的模型Dual TriNet使用更深的ResNet作為骨干網(wǎng)絡(luò),使其在訓(xùn)練樣本較少的Caltech-256數(shù)據(jù)集上有更好的分類效果。
表4 與其他方法在各數(shù)據(jù)集上的5-way分類結(jié)果對比Table 4 Comparison of 5-way classification results with other methods on different datasets %
值得注意的是,本文方法在部分?jǐn)?shù)據(jù)集上結(jié)果優(yōu)于Dual TriNet,在MniImagenet 上5-way 1-shot 與5-way 5-shot 分類結(jié)果分別提高3.3 個(gè)百分點(diǎn)與1.7 個(gè)百分點(diǎn),在Cifar-100 上分別提升3.0 個(gè)百分點(diǎn)與3.9 個(gè)百分點(diǎn)。證明了本文方法對特征信息的提取能力與ResNet模型有較強(qiáng)競爭力。
3.3.4 效率分析
由于本文在傳統(tǒng)度量學(xué)習(xí)的基礎(chǔ)上,構(gòu)建了較為復(fù)雜的圖神經(jīng)網(wǎng)絡(luò),并且在圖中添加了元學(xué)習(xí)器,增加了網(wǎng)絡(luò)的參數(shù)量,提高了網(wǎng)絡(luò)復(fù)雜度。為驗(yàn)證本文方法的執(zhí)行效率,對其參數(shù)進(jìn)行了分析,對其運(yùn)行速度進(jìn)行了實(shí)驗(yàn),并與不同方法進(jìn)行對比分析。
本文采用5-way 1-shot的設(shè)定進(jìn)行對比實(shí)驗(yàn),每類隨機(jī)抽取15個(gè)樣本作為查詢集,批大小設(shè)置為40,結(jié)果如表5 所示。MN[17]和PN[18]使用四層卷積做嵌入網(wǎng)絡(luò),參數(shù)量低(1.2×106),方法實(shí)時(shí)性好。RN[33]使用了更復(fù)雜的嵌入神經(jīng)網(wǎng)絡(luò),運(yùn)算復(fù)雜度較高,效率較低。EGNN[35]使用四層卷積作為特征提取器,其構(gòu)建的圖神經(jīng)網(wǎng)絡(luò)中結(jié)點(diǎn)與邊的更新通過卷積神經(jīng)網(wǎng)絡(luò)和全連接層實(shí)現(xiàn),模型參數(shù)量大,效率較低。本文方法多尺度特征提取器參數(shù)量(9.0×105)小于四層卷積,而掩碼圖部分,更新時(shí)使用的點(diǎn)乘注意力未增加參數(shù)量,元學(xué)習(xí)器參數(shù)量為4.0×105。整個(gè)網(wǎng)絡(luò)參數(shù)量為1.3×106,與PN相近,但是本文跨尺度連接、圖更新、掩碼生成等操作增加了計(jì)算量。實(shí)驗(yàn)時(shí)發(fā)現(xiàn)掩碼圖更新次數(shù)的改變對分類準(zhǔn)確率的影響較大,但是對耗時(shí)的影響輕微,因此采用分類效果最好的3次更新與其他方法進(jìn)行對比,此時(shí)本文方法耗時(shí)約為PN的1.2倍,但是分類準(zhǔn)確率相比于PN提升明顯。此外,本文方法在獲得更高準(zhǔn)確率時(shí)的效率仍高于EGNN,說明本文方法在達(dá)到較高分類性能的情況下保持較高效率。
表5 不同方法效率對比Table 5 Efficiency comparison of different methods
為了驗(yàn)證本文方法的有效性,并對本文方法中多尺度特征提取器、掩碼圖網(wǎng)絡(luò)、特征貢獻(xiàn)度和互斥損失三部分的效果有進(jìn)一步了解,在MiniImagenet、Cifar-100、Caltech-256 數(shù)據(jù)集上進(jìn)行了消融實(shí)驗(yàn)的研究,結(jié)果如表6 所示。本文模型在原型類度量網(wǎng)絡(luò)的基礎(chǔ)上融入多尺度特征模塊以及掩碼圖模塊,因此消融實(shí)驗(yàn)采用PN[18]作為對比的基準(zhǔn)方法。
3.4.1 多尺度特征提取器的效果
如表6 所示,使用“PN+多尺度特征”的分類準(zhǔn)確率與PN 相比在MiniImagenet 上1-shot 與5-shot 分別提升5.1 個(gè)百分點(diǎn)與6.1 個(gè)百分點(diǎn),在Cifar-100 上分別提升4.7個(gè)百分點(diǎn)與1.4個(gè)百分點(diǎn),證明多尺度特征提取模塊提取到的特征比單一尺度特征更為有效。類似的,“PN+多尺度特征+掩碼圖”與僅使用掩碼圖相比,在MiniImagenet上1-shot與5-shot都提升了1.6個(gè)百分點(diǎn),在Cifar-100上分別提升1.6個(gè)百分點(diǎn)與1.7個(gè)百分點(diǎn),這說明多尺度特征提取器提取到的信息在圖結(jié)構(gòu)中是可傳播的,并且這些信息對分類起到了積極作用。
表6 本文方法在不同數(shù)據(jù)集上的5-way消融實(shí)驗(yàn)Table 6 5-way abalation experiment of our methods on different datasets %
3.4.2 掩碼圖網(wǎng)絡(luò)的效果
如表6 所示,使用“PN+掩碼圖”與PN 相比,1-shot和5-shot 在MiniImagenet 上分別提升了9.6 個(gè)百分點(diǎn)與7.7 個(gè)百分點(diǎn),在Cifar-100 上分別提升7.8 個(gè)百分點(diǎn)與4.0 個(gè)百分點(diǎn);“PN+多尺度特征+掩碼圖”與僅使用多尺度特征相比,1-shot 和5-shot 在MiniImagenet 上分別提升了6.1個(gè)百分點(diǎn)與3.2個(gè)百分點(diǎn),在Cifar-100上分別提升4.7 個(gè)百分點(diǎn)與4.3 個(gè)百分點(diǎn)。這表明掩碼圖對特征信息挖掘以及對特征的增強(qiáng)具有很好的效果。
為了進(jìn)一步了解本文掩碼圖網(wǎng)絡(luò)的有效性,本文可視化了掩碼圖中的邊特征,如圖8 所示。圖中矩陣為5個(gè)支持集樣本與對應(yīng)類別的5 個(gè)查詢集樣本兩兩之間的邊特征,圖8(a)~(d)分別為第1、2、3次掩碼圖更新后的邊和邊的真實(shí)值。矩陣中不同顏色表示不同邊特征值,樣本越相似,邊特征值越大越接近紅色,反之,邊特征越小越接近藍(lán)色。如圖所示,經(jīng)過掩碼圖的更新,邊特征快速向真值變化,并且每次更新過后,邊特征的差異程度也發(fā)生變化,需要元學(xué)習(xí)器動(dòng)態(tài)學(xué)習(xí)這種變化,增強(qiáng)泛化能力。
圖8 掩碼圖中邊特征的可視化Fig.8 Visualization of edge features in mask GNN
但是當(dāng)同類樣本差異較大的情況下,本文的元學(xué)習(xí)器在生成邊掩碼時(shí)會將同類樣本判斷為不同類別,從而切斷它們之間的信息傳播,造成掩碼圖網(wǎng)絡(luò)失效,如圖9所示。圖9(a)為支持集樣本,圖9(b)為查詢集樣本,圖9(c)為它們對應(yīng)邊特征矩陣的可視化。同一列樣本屬于同一類,由左至右分別為鍵盤、蚌、電腦顯示屏、筆記本電腦、保齡球。由圖9 可以看出,當(dāng)同類樣本不相似,而不同類樣本相似時(shí),邊特征更新結(jié)果與真值相差較大。說明本文元學(xué)習(xí)器區(qū)分結(jié)點(diǎn)間信息能否用于傳播時(shí)對樣本差異度有很強(qiáng)的依賴性,這是由于元學(xué)習(xí)器的輸入為樣本差異矩陣導(dǎo)致的。
圖9 困難任務(wù)邊特征的可視化Fig.9 Visualization of edge features of difficult task
3.4.3 特征貢獻(xiàn)度和互斥損失的效果
特征貢獻(xiàn)度和互斥損失應(yīng)用于“PN+掩碼圖”與“PN+多尺度特征+掩碼圖”,模型分類效果也都有提升,證明了本文對均值類原型計(jì)算方法改進(jìn)的有效性。圖10是本文互斥損失添加前后模型訓(xùn)練階段損失的變化情況。
圖10 損失函數(shù)曲線Fig.10 Loss curves
由圖10 可以看出,本文組合損失在訓(xùn)練初始階段為1.3,高于交叉熵?fù)p失的0.8。這是由于本文組合損失在交叉熵?fù)p失的基礎(chǔ)上加上了互斥損失,訓(xùn)練開始階段交叉熵?fù)p失相差不大的情況下本文組合損失的值更大。隨后兩者都開始下降,在42 000個(gè)Epoch時(shí)本文組合損失開始低于交叉熵?fù)p失,并在之后的訓(xùn)練中保持較低值直到收斂。這是由于本文的互斥損失促進(jìn)了原型互斥地生成,提升分類準(zhǔn)確率,降低了交叉熵?fù)p失,使總損失維持在較低水平。本文組合損失在70 000 個(gè)Epoch 時(shí)開始收斂,而交叉熵?fù)p失在50 000 個(gè)Epoch 時(shí)開始收斂,相比于本文收斂更快。分析原因,是由于組合損失更為復(fù)雜,在模型參數(shù)量不變的情況下,加入的互斥損失在模型訓(xùn)練的前中期維持在較高值,延緩了收斂速度,在模型訓(xùn)練后期降到較低值。消融實(shí)驗(yàn)結(jié)果及損失曲線表明提出的互斥損失優(yōu)化了模型訓(xùn)練,使模型分類性能提高。
本文提出的模型基于度量學(xué)習(xí)與元學(xué)習(xí)方法,致力于解決小樣本分類訓(xùn)練樣本過少導(dǎo)致的可用于模型訓(xùn)練信息不足的問題。傳統(tǒng)度量學(xué)習(xí)存在僅使用頂層特征造成的信息單一的問題,對此本文設(shè)計(jì)了一種多尺度特征提取器,使模型提取到樣本信息更為豐富的多尺度融合特征,用于模型后續(xù)分類;傳統(tǒng)基于圖神經(jīng)網(wǎng)絡(luò)的方法在處理小樣本分類時(shí)存在結(jié)點(diǎn)無差別更新的問題,對此本文結(jié)合元學(xué)習(xí)機(jī)制以生成掩碼的方式進(jìn)行圖結(jié)點(diǎn)的選擇性更新,掩碼圖通過結(jié)點(diǎn)間更為有效的信息交互進(jìn)一步增強(qiáng)了多尺度融合特征;此外,本文提出特征貢獻(xiàn)度和互斥損失對均值類原型求解過程進(jìn)行改進(jìn),以更好地利用增強(qiáng)的多尺度特征進(jìn)行分類。本文在Mini-Imagenet、Caltech-256和Cifar-100數(shù)據(jù)集上與傳統(tǒng)模型及較先進(jìn)模型進(jìn)行比較,在MiniImagenet 上,傳統(tǒng)方法1-shot 準(zhǔn)確率為49.4%,5-shot 準(zhǔn)確率為68.2%,本文方法分別為61.4%和78.6%,分別超過傳統(tǒng)方法12.0 個(gè)百分點(diǎn)與10.4 個(gè)百分點(diǎn)。實(shí)驗(yàn)表明本文方法相比于傳統(tǒng)方法有了較大提升,并達(dá)到了先進(jìn)水平。
本文方法還存在以下不足需要進(jìn)一步研究:(1)多尺度特征在融合時(shí)在各尺度使用了單一的可學(xué)習(xí)參數(shù)作為注意力機(jī)制進(jìn)行多尺度特征的融合,導(dǎo)致各尺度部分關(guān)鍵信息被弱化,部分干擾信息被強(qiáng)化,影響分類效果。后續(xù)考慮實(shí)現(xiàn)能區(qū)分單一尺度上信息重要性的注意力機(jī)制。(2)元學(xué)習(xí)器識別結(jié)點(diǎn)間信息有效性時(shí)對樣本差異度有過強(qiáng)的依賴性。為了進(jìn)一步提高元學(xué)習(xí)器識別有效信息的能力,減弱其對樣本差異度的依賴,考慮改進(jìn)元學(xué)習(xí)器的結(jié)構(gòu)。強(qiáng)化學(xué)習(xí)可以通過反饋機(jī)制來訓(xùn)練智能體,未來的研究中,考慮將本文元學(xué)習(xí)器與強(qiáng)化學(xué)習(xí)結(jié)合到一起,并通過困難樣本的強(qiáng)化訓(xùn)練,增強(qiáng)元學(xué)習(xí)器對困難任務(wù)的處理能力。