劉 兵,楊 娟,汪榮貴,薛麗霞
合肥工業(yè)大學(xué) 計(jì)算機(jī)與信息學(xué)院,合肥 230601
近年來(lái),深度學(xué)習(xí)在圖像分類[1-3]、目標(biāo)檢測(cè)[4-6]、機(jī)器翻譯[7-8]等人工智能任務(wù)上取得了卓越的進(jìn)展。在某些圖像領(lǐng)域中,深度模型的分類、檢測(cè)、識(shí)別能力已經(jīng)接近甚至超越了人類。然而這些成就無(wú)不依賴于一個(gè)限制,即訓(xùn)練一個(gè)有效的深度模型需要大量的帶標(biāo)簽樣本。當(dāng)訓(xùn)練樣本不足時(shí),深度模型很容易會(huì)產(chǎn)生過(guò)擬合問(wèn)題,導(dǎo)致學(xué)習(xí)失敗。對(duì)于深度模型而言通過(guò)少量樣本進(jìn)行學(xué)習(xí)是相當(dāng)困難的任務(wù),但對(duì)于人類來(lái)說(shuō)卻是十分簡(jiǎn)單的過(guò)程,例如,即使是從未見(jiàn)過(guò)“老虎”形象的孩子僅通過(guò)幾張“老虎”圖片就能習(xí)得“老虎”的視覺(jué)概念。受此啟發(fā),小樣本學(xué)習(xí)應(yīng)運(yùn)而生。小樣本學(xué)習(xí)任務(wù)是指在訓(xùn)練過(guò)程中,對(duì)于從未見(jiàn)過(guò)的類別,在給出這些類別的少量樣本的情況下,仍然能學(xué)習(xí)到一個(gè)具有優(yōu)秀辨別力的分類器。由于小樣本學(xué)習(xí)關(guān)注于深度模型在樣本量受限的情況下的學(xué)習(xí)問(wèn)題,故小樣本學(xué)習(xí)的應(yīng)用非常廣泛,例如,對(duì)醫(yī)療影像中的罕見(jiàn)病例進(jìn)行識(shí)別分類來(lái)輔助診斷,或是在海量監(jiān)控視頻中對(duì)嫌疑人進(jìn)行搜索識(shí)別來(lái)輔助偵察等。這些任務(wù)之間存在著明顯的共性,即僅通過(guò)幾張帶標(biāo)簽的樣本來(lái)學(xué)習(xí)一個(gè)有效的分類器,而不需要百萬(wàn)乃至千萬(wàn)級(jí)別的標(biāo)注數(shù)據(jù)。除此之外,小樣本學(xué)習(xí)還極大地減輕了樣本標(biāo)注的工作量。故旨在通過(guò)少量的標(biāo)注樣本來(lái)習(xí)得一個(gè)新視覺(jué)概念的小樣本學(xué)習(xí),正在吸引更多的關(guān)注。
小樣本學(xué)習(xí)方法大致分為三種:數(shù)據(jù)增強(qiáng)、度量學(xué)習(xí)以及元學(xué)習(xí)。由于小樣本學(xué)習(xí)中最根本的問(wèn)題是訓(xùn)練樣本的匱乏,故數(shù)據(jù)增強(qiáng)方法嘗試借助一些額外信息來(lái)擴(kuò)充當(dāng)前的訓(xùn)練數(shù)據(jù)。一些典型方法[9]提出將樣本映射到特征域,并在特征域進(jìn)行增強(qiáng),如將數(shù)據(jù)映射到語(yǔ)義空間,借助額外的語(yǔ)義信息對(duì)數(shù)據(jù)進(jìn)行增強(qiáng)。然而數(shù)據(jù)增強(qiáng)方法生成的樣本與原樣本之間存在視覺(jué)相似性,很難從根本上解決模型因樣本不足而產(chǎn)生的過(guò)擬合問(wèn)題。度量學(xué)習(xí)方法的思想比較直接,它通過(guò)學(xué)習(xí)樣本與特征之間的映射關(guān)系,將樣本映射到一個(gè)公共的特征空間,在這個(gè)特征空間中,樣本之間采用相似度進(jìn)行度量,同類別之間的樣本距離較近,不同類別之間的樣本距離較遠(yuǎn),查詢樣本通過(guò)尋找最近鄰的類別來(lái)實(shí)現(xiàn)分類。然而由于極端的樣本量限制,學(xué)習(xí)一個(gè)高質(zhì)量的特征空間相當(dāng)困難。與度量學(xué)習(xí)方法不同,元學(xué)習(xí)方法是任務(wù)級(jí)別的方法。元學(xué)習(xí)方法通過(guò)基礎(chǔ)學(xué)習(xí)器與元學(xué)習(xí)器之間的協(xié)調(diào)工作,來(lái)得到一個(gè)最優(yōu)的參數(shù)狀態(tài),模型基于該參數(shù)狀態(tài)僅需少量樣本的迭代學(xué)習(xí)即可習(xí)得新類別。具體而言,在學(xué)習(xí)過(guò)程中,基礎(chǔ)學(xué)習(xí)器從每個(gè)獨(dú)立任務(wù)中快速地獲取知識(shí)并將這些知識(shí)傳遞給元學(xué)習(xí)器,元學(xué)習(xí)器通過(guò)積累大量任務(wù)上的總體知識(shí)來(lái)達(dá)到一個(gè)最優(yōu)的參數(shù)狀態(tài),用該參數(shù)狀態(tài)來(lái)更新基礎(chǔ)學(xué)習(xí)器的參數(shù),此時(shí)基礎(chǔ)學(xué)習(xí)器便擁有了在新任務(wù)上快速學(xué)習(xí)的能力,因而適合小樣本學(xué)習(xí)。
元學(xué)習(xí)中一個(gè)代表性的方法是MAML(modelagnostic meta-learning)[10]。在MAML的一輪迭代中,基礎(chǔ)學(xué)習(xí)器學(xué)習(xí)多個(gè)獨(dú)立任務(wù),并在每個(gè)任務(wù)中將由誤差產(chǎn)生的梯度信息傳遞給元學(xué)習(xí)器,元學(xué)習(xí)器通過(guò)累積這些任務(wù)上的梯度信息來(lái)獲得經(jīng)驗(yàn)風(fēng)險(xiǎn)最小化的梯度,進(jìn)而優(yōu)化元學(xué)習(xí)器參數(shù),最后再用更新后的元學(xué)習(xí)器參數(shù)初始化基礎(chǔ)學(xué)習(xí)器,來(lái)進(jìn)行下一輪的迭代。由于元學(xué)習(xí)器學(xué)習(xí)大量任務(wù)的整體知識(shí),故元學(xué)習(xí)器的參數(shù)在收斂時(shí)便具備了在各個(gè)任務(wù)上快速泛化的能力,此時(shí)用元學(xué)習(xí)器的參數(shù)來(lái)初始化基礎(chǔ)學(xué)習(xí)器,通過(guò)若干次微調(diào),基礎(chǔ)學(xué)習(xí)器便可以快速學(xué)習(xí)新任務(wù)。
然而,MAML 通過(guò)微調(diào)基礎(chǔ)學(xué)習(xí)器的學(xué)習(xí)方式依然困難重重,認(rèn)為有兩個(gè)因素限制MAML 的有效性。首先,在小樣本學(xué)習(xí)中,基礎(chǔ)學(xué)習(xí)器雖然具備在新任務(wù)上快速泛化的能力,但是相對(duì)于需要微調(diào)的參數(shù)量,用來(lái)微調(diào)的樣本仍然過(guò)少,這會(huì)導(dǎo)致微調(diào)效果不佳。其次,基礎(chǔ)學(xué)習(xí)器的初始化狀態(tài)是元學(xué)習(xí)器在大量任務(wù)上習(xí)得的總體趨勢(shì),而在具體任務(wù)中并不是最優(yōu)的狀態(tài),尤其是當(dāng)新任務(wù)偏離總體趨勢(shì)時(shí),基礎(chǔ)學(xué)習(xí)器的初始化狀態(tài)并不足以支撐其在新任務(wù)上的微調(diào)。
為了解決這兩個(gè)問(wèn)題,提出了基于記憶的遷移學(xué)習(xí)方法,本文提出以下創(chuàng)新點(diǎn):
(1)提出了權(quán)重分解策略來(lái)解決基礎(chǔ)學(xué)習(xí)器中需要微調(diào)的參數(shù)量過(guò)多的問(wèn)題。具體而言,將部分卷積層的權(quán)重進(jìn)行分解,將具有泛化能力的部分作為泛化權(quán)重,對(duì)任務(wù)敏感的部分作為敏感權(quán)重。將預(yù)訓(xùn)練權(quán)重固定下來(lái)作為泛化權(quán)重,僅微調(diào)敏感權(quán)重來(lái)學(xué)習(xí)新任務(wù)。這樣的分解策略在小樣本學(xué)習(xí)中很有意義,泛化權(quán)重可以保證模型在學(xué)習(xí)過(guò)程中始終具備泛化能力,避免過(guò)擬合的發(fā)生,與此同時(shí),需要微調(diào)的參數(shù)量極大地減少了,意味著學(xué)習(xí)過(guò)程更容易收斂。
(2)借助一個(gè)額外的記憶模塊來(lái)更有效地初始化基礎(chǔ)學(xué)習(xí)器。在元學(xué)習(xí)階段,模型通過(guò)微調(diào)敏感權(quán)重來(lái)學(xué)習(xí)新任務(wù),借助一個(gè)記憶模塊將任務(wù)信息與敏感權(quán)重信息關(guān)聯(lián)起來(lái),并存儲(chǔ)于記憶中,隨著網(wǎng)絡(luò)對(duì)任務(wù)的學(xué)習(xí)進(jìn)行同步更新。每當(dāng)遇到新任務(wù)時(shí),記憶模塊便根據(jù)當(dāng)前任務(wù)在記憶中查找最相關(guān)的任務(wù),來(lái)獲取最相關(guān)任務(wù)對(duì)應(yīng)的敏感權(quán)重,并用該敏感權(quán)重來(lái)初始化當(dāng)前模型的參數(shù),使得初始化參數(shù)具備任務(wù)相關(guān)性,有效地避免了新任務(wù)偏離整體趨勢(shì)時(shí),初始化狀態(tài)難以支撐微調(diào)的情況,進(jìn)一步提升微調(diào)效果。
小樣本學(xué)習(xí)問(wèn)題取得的研究進(jìn)展主要為以下幾個(gè)方向:數(shù)據(jù)增強(qiáng),度量學(xué)習(xí)以及元學(xué)習(xí)。最近,轉(zhuǎn)導(dǎo)學(xué)習(xí)同樣獲得了不少的關(guān)注由于其在分類效果上的提升。
小樣本學(xué)習(xí)的關(guān)鍵問(wèn)題是缺乏足夠的訓(xùn)練樣本,故數(shù)據(jù)增強(qiáng)方法是解決小樣本問(wèn)題最自然的方法。由于標(biāo)準(zhǔn)的數(shù)據(jù)增強(qiáng)方法,如裁剪、旋轉(zhuǎn)、加噪等方法,生成的圖片與原始圖片存在著極大的視覺(jué)相似性,很難在小樣本學(xué)習(xí)中起作用,故在小樣本問(wèn)題中,數(shù)據(jù)增強(qiáng)方法往往需要借助額外數(shù)據(jù),從中獲取可遷移的知識(shí)來(lái)擴(kuò)充訓(xùn)練數(shù)據(jù)。Zhang等人[11]提出通過(guò)兩個(gè)特征提取器分別提取圖像的前景特征與背景特征,通過(guò)將不同的前景與背景進(jìn)行組合,來(lái)生成更多的合成圖像,以此實(shí)現(xiàn)數(shù)據(jù)集的擴(kuò)充。Wang 等人[12]則是在特征域構(gòu)建一個(gè)生成器,通過(guò)對(duì)特征進(jìn)行加噪,來(lái)生成新的實(shí)例。Chen 等人[9]進(jìn)一步將視覺(jué)特征映射到語(yǔ)義空間,在語(yǔ)義空間中借助語(yǔ)義信息進(jìn)行數(shù)據(jù)擴(kuò)充,通過(guò)將擴(kuò)充數(shù)據(jù)映射回視覺(jué)空間來(lái)獲得更多的擴(kuò)充樣例。
度量學(xué)習(xí)方法將樣本映射到一個(gè)低維的嵌入空間中,在這個(gè)空間中樣本的特征變得更具有辨別力,通過(guò)度量的方法對(duì)樣本進(jìn)行分類。Koch 等人[13]借助孿生網(wǎng)絡(luò)結(jié)構(gòu)來(lái)學(xué)習(xí)兩個(gè)輸入樣本之間的相似度,通過(guò)對(duì)相似度分?jǐn)?shù)排序來(lái)實(shí)現(xiàn)分類。Vinyals 等人[14]在樣本到嵌入空間的映射過(guò)程中加入了注意力機(jī)制,并首次提出了episodes 訓(xùn)練策略。Snell 等人[15]提出在嵌入空間中,各個(gè)類別都以各自的類別原型表示,查詢樣本通過(guò)計(jì)算與各個(gè)類別原型之間的歐式距離作為與各個(gè)類別的相似度,最后基于相似度進(jìn)行分類。Sung等人[16]采用神經(jīng)網(wǎng)絡(luò)來(lái)學(xué)習(xí)距離度量的方式而不是固定的度量方式。Li等人[17]通過(guò)比較查詢樣本與各個(gè)類別之間的局部描述子來(lái)尋找最接近的類別。Li等人[18]使用協(xié)方差矩陣來(lái)表示每個(gè)類別,同時(shí)提出協(xié)方差度量方式來(lái)進(jìn)行距離的度量。
元學(xué)習(xí)方法通過(guò)學(xué)習(xí)一系列相關(guān)任務(wù)來(lái)歸納出這些任務(wù)的本質(zhì)規(guī)律,當(dāng)面對(duì)新的任務(wù)時(shí),可以根據(jù)習(xí)得的知識(shí)快速擬合與泛化。Finn 等人[10]通過(guò)元學(xué)習(xí)器積累任務(wù)的總體趨勢(shì),以此來(lái)更新基礎(chǔ)學(xué)習(xí)器的參數(shù),使得基礎(chǔ)學(xué)習(xí)器在遇到新的學(xué)習(xí)任務(wù)時(shí),具有快速擬合的能力。Ravi 等人[19]用基于LSTM 的元學(xué)習(xí)器來(lái)模擬梯度下降的過(guò)程,LSTM通過(guò)其細(xì)胞狀態(tài)來(lái)更新分類器網(wǎng)絡(luò)的參數(shù),最后在新的任務(wù)中來(lái)指導(dǎo)分類器網(wǎng)絡(luò)的更新。Li 等人[20]提出不僅學(xué)習(xí)基礎(chǔ)學(xué)習(xí)器的梯度下降過(guò)程,而且還學(xué)習(xí)基礎(chǔ)學(xué)習(xí)器的更新方向和學(xué)習(xí)速率。
轉(zhuǎn)導(dǎo)學(xué)習(xí)方法提出將所有的待預(yù)測(cè)樣本送入網(wǎng)絡(luò)并同時(shí)進(jìn)行預(yù)測(cè),以此來(lái)學(xué)習(xí)所有樣本(包含帶標(biāo)簽樣本與不帶標(biāo)簽樣本)之間的關(guān)系。Liu 等人[21]利用所有的樣本來(lái)進(jìn)行轉(zhuǎn)導(dǎo)推理,在這一過(guò)程中將標(biāo)簽從帶標(biāo)簽的樣本傳遞至不帶標(biāo)簽的樣本。Ye 等人[22]提出通過(guò)自注意力機(jī)制將任務(wù)無(wú)關(guān)的樣本特征轉(zhuǎn)換成任務(wù)相關(guān)的樣本特征,從而更好地進(jìn)行分類。Li等人[23]借助樣本實(shí)例與其鄰居實(shí)例之間的關(guān)系來(lái)實(shí)現(xiàn)對(duì)該樣本的特征增強(qiáng)。
遷移學(xué)習(xí)的目標(biāo)是將在某些任務(wù)上學(xué)習(xí)到的知識(shí)或經(jīng)驗(yàn)應(yīng)用到不同但相關(guān)的任務(wù)中。對(duì)于深度模型,一種行之有效的遷移學(xué)習(xí)方法是將預(yù)訓(xùn)練模型應(yīng)用于新任務(wù),稱之為微調(diào)。在小樣本學(xué)習(xí)中,遷移學(xué)習(xí)通過(guò)微調(diào)在大量任務(wù)上預(yù)訓(xùn)練的模型實(shí)現(xiàn)在新任務(wù)上快速學(xué)習(xí),這些任務(wù)之間需要存在一定的相關(guān)性,如共享的特征、相似的語(yǔ)義屬性或是相關(guān)的上下文信息。在度量學(xué)習(xí)[14-15]中,通過(guò)將在源數(shù)據(jù)域上習(xí)得的度量空間遷移到新類別的方法取得了不錯(cuò)的效果。元學(xué)習(xí)方法[10,19]往往也依賴于遷移學(xué)習(xí),元學(xué)習(xí)器習(xí)得了跨任務(wù)的知識(shí)后,指導(dǎo)基礎(chǔ)學(xué)習(xí)器學(xué)習(xí)新任務(wù)的過(guò)程往往采取微調(diào)的方法。比如MAML 在每一次迭代中,元學(xué)習(xí)器都會(huì)指導(dǎo)基礎(chǔ)學(xué)習(xí)器的初始化,當(dāng)遇到一個(gè)新任務(wù)時(shí),基礎(chǔ)學(xué)習(xí)器通過(guò)微調(diào)的方法來(lái)快速適應(yīng)這個(gè)任務(wù)。
在小樣本學(xué)習(xí)中,元學(xué)習(xí)方法在解決小樣本問(wèn)題時(shí)的核心思路是用源數(shù)據(jù)域上的可遷移知識(shí)來(lái)幫助新類的學(xué)習(xí),因此記憶網(wǎng)絡(luò)常作為知識(shí)遷移的媒介應(yīng)用于元學(xué)習(xí)方法中。記憶網(wǎng)絡(luò)的一個(gè)應(yīng)用是作為注意力模塊來(lái)幫助網(wǎng)絡(luò)進(jìn)行學(xué)習(xí)。如MN(matching nets)[14]基于LSTM 提供注意力機(jī)制來(lái)嘗試挖掘查詢樣本與訓(xùn)練樣本之間的聯(lián)系,以此使得查詢樣本在嵌入空間中更具有辨識(shí)性。記憶網(wǎng)絡(luò)的另一個(gè)應(yīng)用是作為一個(gè)存儲(chǔ)信息的記憶模塊,在訓(xùn)練時(shí)將先驗(yàn)知識(shí)存儲(chǔ)到記憶模塊中,在測(cè)試時(shí)使用這些信息進(jìn)行預(yù)測(cè)。Santoro 等人[24]提出借助神經(jīng)圖靈機(jī)(NTMs)將特征信息與對(duì)應(yīng)標(biāo)簽關(guān)聯(lián)起來(lái),以此實(shí)現(xiàn)特征向量準(zhǔn)確分類。He等人[25]提出在學(xué)習(xí)過(guò)程中將大量特征和標(biāo)簽存儲(chǔ)于記憶中,當(dāng)學(xué)習(xí)新類時(shí),借助記憶中的信息對(duì)當(dāng)前任務(wù)的特征進(jìn)行增強(qiáng)。
基于記憶的遷移學(xué)習(xí)方法最核心的思想是借助一個(gè)記憶模塊來(lái)給元學(xué)習(xí)器(分類器)提供一個(gè)最優(yōu)的初始化狀態(tài),從而實(shí)現(xiàn)在新任務(wù)上的快速學(xué)習(xí)。如圖1是整體的網(wǎng)絡(luò)結(jié)構(gòu)。支持集樣本經(jīng)過(guò)特征提取器后輸出對(duì)應(yīng)的特征表示,之后在記憶模塊中,所有的特征表示會(huì)被下采樣為一個(gè)任務(wù)級(jí)別的表示,在接收到任務(wù)表示后,讀控制器輸出與當(dāng)前任務(wù)最相關(guān)的權(quán)重信息,這些權(quán)重信息可以有效地初始化分類器網(wǎng)絡(luò)中的敏感權(quán)重。最后敏感權(quán)重經(jīng)過(guò)簡(jiǎn)單微調(diào)后,與泛化權(quán)重共同作用組成分類網(wǎng)絡(luò)權(quán)重對(duì)樣本特征進(jìn)行分類。每個(gè)任務(wù)完成學(xué)習(xí)后,更新后的敏感權(quán)重信息與當(dāng)前的任務(wù)信息會(huì)進(jìn)行配對(duì)存儲(chǔ)進(jìn)記憶中。
圖1 基于記憶的遷移學(xué)習(xí)方法整體框架圖Fig.1 Overview of proposed memory-based transfer learning
對(duì)于N-wayK-shot的小樣本分類任務(wù),每個(gè)任務(wù)T由支持集與查詢集兩部分實(shí)例集合組成。其中,支持集S={(x1,1,y1),(x1,2,y1),…,(xN,K,yN)}由N個(gè)類別中每個(gè)類別采樣K個(gè)帶標(biāo)簽的實(shí)例構(gòu)成,其中xi,j表示第i個(gè)類別中第j個(gè)樣本,yi∈{ }1,2,…,N表示所屬類別,查詢集Q={q1,q2,…,qN×M}由與支持集相同的N個(gè)類別中采樣除支持集S以外的不帶標(biāo)簽的實(shí)例構(gòu)成,即S∩Q=?,qi表示第i個(gè)查詢樣本。小樣本學(xué)習(xí)的最終目標(biāo)是挖掘支持集S的先驗(yàn)知識(shí),并用其預(yù)測(cè)查詢集Q中樣本的類別。
然而,由于訓(xùn)練樣本極度缺乏,直接進(jìn)行預(yù)測(cè)的方式會(huì)面臨嚴(yán)重的過(guò)擬合風(fēng)險(xiǎn)。普遍的方案是借助一個(gè)輔助的元訓(xùn)練集Dbase學(xué)習(xí)可遷移的知識(shí)來(lái)提升網(wǎng)絡(luò)的泛化能力。Dbase由大量屬于Nbase類別的帶標(biāo)簽樣本構(gòu)成,并且與目標(biāo)小樣本任務(wù)的標(biāo)簽空間不相交,即Nbase∩Ntarget=?。與此同時(shí),使用episodes訓(xùn)練策略[14]來(lái)訓(xùn)練網(wǎng)絡(luò),這個(gè)訓(xùn)練策略被廣泛地應(yīng)用于小樣本學(xué)習(xí)的論文中,并取得了不錯(cuò)的效果。即在元訓(xùn)練過(guò)程中,對(duì)于每一個(gè)episodeT? ,從Nbase類別中采樣N個(gè)類別,每個(gè)類別采樣K個(gè)帶標(biāo)簽的樣本作為支持集S?,同樣取這些類別中除S?之外的一部分樣本作為Q?。顯然,每個(gè)訓(xùn)練episode都是在模仿N-wayK-shot的目標(biāo)小樣本任務(wù)。訓(xùn)練模型時(shí)的目標(biāo)定義為:
其中,θ是模型的參數(shù),Pθ(y|x,S?)表示樣本x屬于類別y的概率。通過(guò)大量episodes的學(xué)習(xí)之后,模型會(huì)具有很好的泛化性。
在小樣本學(xué)習(xí)中,很多方法借助微調(diào)預(yù)訓(xùn)練模型來(lái)達(dá)到快速學(xué)習(xí)的目的。然而,相對(duì)于需要微調(diào)的參數(shù)量,可用的數(shù)據(jù)量仍然過(guò)少,故微調(diào)的效果往往會(huì)受到限制,為了解決這一問(wèn)題,提出權(quán)重分解策略。具體而言,將網(wǎng)絡(luò)中部分卷積層的權(quán)重分解為泛化權(quán)重與敏感權(quán)重。在預(yù)訓(xùn)練過(guò)程中,網(wǎng)絡(luò)忽略敏感權(quán)重,僅學(xué)習(xí)泛化權(quán)重的參數(shù),在經(jīng)過(guò)大量任務(wù)的迭代后,泛化權(quán)重具備了很強(qiáng)的泛化能力。在元訓(xùn)練過(guò)程中,網(wǎng)絡(luò)凍結(jié)泛化權(quán)重,僅學(xué)習(xí)敏感權(quán)重的參數(shù),通過(guò)敏感權(quán)重與泛化權(quán)重的共同作用,來(lái)擬合特定任務(wù)。
將分類器的卷積層權(quán)重分解為泛化權(quán)重φ與敏感權(quán)重W兩部分。在預(yù)訓(xùn)練階段,為了與其他小樣本學(xué)習(xí)方法公平比較,該模型僅在小樣本學(xué)習(xí)的訓(xùn)練集上預(yù)訓(xùn)練。例如,在miniImageNet[14]數(shù)據(jù)集上,訓(xùn)練集Dbase中總共包含64 個(gè)類別,每個(gè)類別600 個(gè)樣本,模型在預(yù)訓(xùn)練時(shí)訓(xùn)練一個(gè)64類別的分類網(wǎng)絡(luò)。首先隨機(jī)初始化特征提取器參數(shù)θ以及分類器參數(shù)φ,然后通過(guò)梯度下降法對(duì)他們進(jìn)行優(yōu)化:
其中α表示學(xué)習(xí)率,l為交叉熵?fù)p失。在這個(gè)階段學(xué)習(xí)特征提取器參數(shù)θ以及分類器參數(shù)φ,同時(shí)在驗(yàn)證集上驗(yàn)證網(wǎng)絡(luò)的泛化能力。預(yù)訓(xùn)練結(jié)束后,將θ與φ固定下來(lái),作為泛化權(quán)重。在元訓(xùn)練過(guò)程中,特征提取器參數(shù)保持不變,分類器參數(shù)由泛化權(quán)重φ與敏感權(quán)重W共同組成。
在元訓(xùn)練階段,分類器凍結(jié)泛化權(quán)重,之后我們用帶標(biāo)簽的支持集S對(duì)敏感權(quán)重W進(jìn)行微調(diào),通過(guò)梯度下降法對(duì)W進(jìn)行優(yōu)化:
其中,β為學(xué)習(xí)率,l為交叉熵?fù)p失。通過(guò)對(duì)W進(jìn)行微調(diào),使得W具有任務(wù)相關(guān)性。權(quán)重更新的示意圖如圖2所示。
圖2 權(quán)重更新示意圖Fig.2 Illustration of weights update
在N-wayK-shot 的設(shè)置下,給定一個(gè)支持集S,對(duì)于一個(gè)樣本xn,k,特征提取器輸出其對(duì)應(yīng)的特征圖en,k∈RD。將任務(wù)中所有樣本的特征圖組成的特征圖e∈RN×K×D作為下采樣模塊的輸入,以此來(lái)對(duì)e進(jìn)行任務(wù)級(jí)別的下采樣:
模型在預(yù)訓(xùn)練階段,在Dbase訓(xùn)練集上進(jìn)行傳統(tǒng)的分類任務(wù),由公式(2)進(jìn)行優(yōu)化,保留在驗(yàn)證集上效果最優(yōu)的模型,將其權(quán)重固定下來(lái)作為泛化權(quán)重。在元訓(xùn)練階段,根據(jù)episodes訓(xùn)練策略,每個(gè)episode在Dbase隨機(jī)采樣出N-wayK-shot的批次,送入網(wǎng)絡(luò)進(jìn)行訓(xùn)練。在通過(guò)特征提取器后,根據(jù)公式(4)獲得任務(wù)級(jí)別的特征,再根據(jù)公式(5)、公式(6)在記憶中查找最相關(guān)任務(wù)的索引,根據(jù)索引提取出相關(guān)任務(wù)對(duì)應(yīng)的敏感權(quán)重信息,經(jīng)過(guò)reshape 后,用其來(lái)初始化分類網(wǎng)絡(luò)的敏感權(quán)重。之后根據(jù)公式(3),對(duì)分類網(wǎng)絡(luò)進(jìn)行微調(diào)。最后將任務(wù)級(jí)別的特征作為任務(wù)信息,微調(diào)后的敏感權(quán)重reshape后作為敏感權(quán)重信息,組成鍵值對(duì)存入記憶,完成一個(gè)episode的學(xué)習(xí)。
在上述的訓(xùn)練過(guò)程中,每個(gè)訓(xùn)練批次的設(shè)置都是完全匹配N-wayK-shot的元測(cè)試形式,旨在模仿小樣本的測(cè)試情景。然而,這種匹配機(jī)制意味著訓(xùn)練出來(lái)的模型只適合N-wayK-shot 的情景,很難泛化到N-wayK′-shot 的情況。因此為了增強(qiáng)網(wǎng)絡(luò)在K′-shot 上的泛化性,提出混合訓(xùn)練策略,即在元訓(xùn)練階段,每個(gè)訓(xùn)練批次由不同的shots 數(shù)組成,學(xué)習(xí)一個(gè)統(tǒng)一的結(jié)構(gòu)以此來(lái)適應(yīng)推理階段不同shots 的任務(wù)。由于在記憶模塊中,經(jīng)過(guò)任務(wù)級(jí)別的下采樣后,N-wayK-shot的樣本組成的特征表示e∈RN×K×D被壓縮為RN×C的統(tǒng)一形式,故記憶能接受任意shots數(shù)的任務(wù)。因此在執(zhí)行混合訓(xùn)練策略時(shí),網(wǎng)絡(luò)仍然是一個(gè)統(tǒng)一的模型,與輸入批次中的shots數(shù)無(wú)關(guān)。
miniImageNet數(shù)據(jù)集被廣泛用于小樣本學(xué)習(xí)中,它是ImageNet[26]數(shù)據(jù)集的子集。它包含100 個(gè)不同的類別,每個(gè)類別有600 張圖片。按照前人的工作[19]中廣泛使用的設(shè)置,同樣從數(shù)據(jù)集中劃分出64個(gè)類別作為訓(xùn)練集,16個(gè)類別作為驗(yàn)證集,剩下的20個(gè)類別作為測(cè)試集。
tieredImageNet[27]數(shù)據(jù)集同樣是ImageNet 數(shù)據(jù)集的子集,與miniImageNet 不同的是,它是一個(gè)更新更大的數(shù)據(jù)集。在數(shù)據(jù)量上,它包含608 個(gè)類別,且平均每個(gè)類別有1 281個(gè)樣本;在語(yǔ)義結(jié)構(gòu)上,它將數(shù)據(jù)集劃分成34個(gè)父級(jí)類別來(lái)確保類別之間的語(yǔ)義差距。在小樣本學(xué)習(xí)中,將高層的34 個(gè)父類劃分出20 個(gè)父類作為訓(xùn)練集(對(duì)應(yīng)351 個(gè)最終類別),6個(gè)父類作為驗(yàn)證集(對(duì)應(yīng)97 個(gè)最終類別)以及8個(gè)父類作為測(cè)試集(對(duì)應(yīng)160 個(gè)最終類別)。這種基于語(yǔ)義層次結(jié)構(gòu)的劃分方式,使得不同集合中的數(shù)據(jù)在語(yǔ)義上更加不相關(guān),更加能夠考驗(yàn)?zāi)P偷姆夯阅堋?/p>
CUB[28]數(shù)據(jù)集是關(guān)于鳥(niǎo)的細(xì)粒度分類數(shù)據(jù)集。它包含200個(gè)鳥(niǎo)的類別對(duì)應(yīng)于鳥(niǎo)的物種,共包含11 788張圖片。根據(jù)先前的設(shè)置[29],選擇100個(gè)類別作為訓(xùn)練集,50 個(gè)類別作為驗(yàn)證集,50 個(gè)類別作為測(cè)試集。對(duì)于CUB 數(shù)據(jù)集中的所有圖片,根據(jù)提供的目標(biāo)框裁剪出目標(biāo)區(qū)域作為預(yù)處理操作[27]。
最后,所有數(shù)據(jù)集的圖片都統(tǒng)一為84×84像素的尺寸再輸入到網(wǎng)絡(luò)中。
為了對(duì)比的公平性,采用通用的四層卷積網(wǎng)絡(luò)作為特征提取器,它包括4 個(gè)卷積塊,每個(gè)卷積塊由64 通道的3×3 卷積,批歸一化[30],LeakyReLU 非線性激活函數(shù)以及2×2的最大池化組成。在分類器部分,采用相似的卷積塊與一個(gè)全連接分類層組成,分類器部分的卷積層采用更多通道的卷積核來(lái)獲得更多的通道級(jí)別的信息。
采用Adam算法[31]訓(xùn)練網(wǎng)絡(luò)。在預(yù)訓(xùn)練階段初始學(xué)習(xí)率設(shè)為0.1,并且每10 個(gè)epoch 學(xué)習(xí)率下降為之前的0.1倍。在元訓(xùn)練過(guò)程中,對(duì)泛化權(quán)重進(jìn)行固定,僅學(xué)習(xí)敏感權(quán)重,此時(shí)學(xué)習(xí)率設(shè)為0.001,每個(gè)任務(wù)迭代100 次來(lái)進(jìn)行遷移學(xué)習(xí)。在元測(cè)試過(guò)程中,測(cè)試600 個(gè)epoch來(lái)計(jì)算模型的準(zhǔn)確率,在每個(gè)epoch 中,每個(gè)類別選取15個(gè)樣本組成查詢集。
在miniImageNet、tieredImageNet 以及CUB 數(shù)據(jù)集上進(jìn)行對(duì)比實(shí)驗(yàn)。為了驗(yàn)證MbTL方法的有效性,將其與一些小樣本學(xué)習(xí)方法進(jìn)行對(duì)比,包括Matching Nets[14]、MAML[10]、Meta-learner LSTM[19]、Prototypical Nets[15]、Relation Networks[16]、TPN[21]、DN4[17]、CovaMNet[18]、MNE[23]、FEAT[22]。Table1、Table2 以及Table3 是分類結(jié)果,這里模型的特征提取器都采用傳統(tǒng)的4層卷積塊,輸出的特征圖的通道數(shù)有64 和32,分別使用Conv-4-64 和Conv-4-32標(biāo)出,最優(yōu)的效果加粗標(biāo)出??梢园l(fā)現(xiàn)我們的方法明顯優(yōu)于大多數(shù)的小樣本學(xué)習(xí)方法,并且在tieredImageNet和CUB數(shù)據(jù)集上,在5-way 1-shot和5 way 5-shot設(shè)定下都取得了最好的分類效果。
表1 展示了在miniImageNet 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果。對(duì)比于TPN,本文方法在5-way 1-shot的設(shè)置下提升了接近0.8 個(gè)百分點(diǎn),而在5-way 5-shot 的設(shè)置下提升了接近2.7 個(gè)百分點(diǎn)。對(duì)比于MNE 和FEAT,在5-way 1-shot 的設(shè)置下,同樣memory-based 的MNE 取得了最優(yōu)的效果,在5-way 5-shot 的設(shè)置下,本文方法略微領(lǐng)先。表2 展示了在CUB 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果。在5-way 1-shot 的設(shè)置下,對(duì)比于CovaMNet,本文方法有著接近16 個(gè)百分點(diǎn)的顯著提升。在5-way 5-shot 的設(shè)置下,對(duì)比于DN4,本文方法仍然有著大約7.5個(gè)百分點(diǎn)的顯著提升。這些結(jié)果表明了本文方法對(duì)細(xì)粒度分類也同樣有效。表3 展示了在tieredImageNet 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果。對(duì)比于TPN 和MNE,在5-way 1-shot 的設(shè)置下,本文方法有著大約1 個(gè)百分點(diǎn)的提升,5-way 5-shot的設(shè)置下,本文方法有著3個(gè)百分點(diǎn)的明顯提升。
表1 不同方法在miniImageNet數(shù)據(jù)集上的準(zhǔn)確率Table 1 Accuracies of different methods on miniImageNet dataset
表2 不同方法在tieredImageNet數(shù)據(jù)集上的準(zhǔn)確率Table 2 Accuracies of different methods on tieredImageNet dataset
表3 不同方法在CUB數(shù)據(jù)集上的準(zhǔn)確率Table 3 Accuracies of different methods on CUB dataset
如圖3 展示了在5-way 1-shot 的設(shè)置下,模型在miniImageNet、tieredImageNet、CUB 數(shù)據(jù)集上的訓(xùn)練損失,驗(yàn)證準(zhǔn)確率以及測(cè)試準(zhǔn)確率。可以看出隨著epoch的迭代,訓(xùn)練損失總體上逐漸下降,驗(yàn)證準(zhǔn)確率對(duì)應(yīng)逐漸上升。由于經(jīng)過(guò)了預(yù)訓(xùn)練,模型處于一個(gè)較優(yōu)的狀態(tài),訓(xùn)練損失與驗(yàn)證準(zhǔn)確率的變化范圍不大。選擇在驗(yàn)證準(zhǔn)確率最大時(shí)對(duì)應(yīng)的模型來(lái)進(jìn)行效果測(cè)試,并給出在測(cè)試集上的測(cè)試效果。
圖3 訓(xùn)練損失、驗(yàn)證準(zhǔn)確率以及測(cè)試準(zhǔn)確率Fig.3 Train loss,validation accuracy and test accuracy
為了驗(yàn)證本文方法的有效性,并對(duì)權(quán)重分解、記憶模塊的效果進(jìn)行進(jìn)一步評(píng)估,在miniImageNet、tieredImageNet以及CUB上開(kāi)展了一系列消融實(shí)驗(yàn)。消融實(shí)驗(yàn)的結(jié)果如表4所示。
表4 消融實(shí)驗(yàn)Table 4 Ablation experiment單位:%
3.4.1 權(quán)重分解的效果
為了驗(yàn)證權(quán)重分解策略的效果,暫時(shí)屏蔽記憶模塊來(lái)消除其對(duì)結(jié)果影響。如表4所示,“MAML”的微調(diào)方式是對(duì)基礎(chǔ)學(xué)習(xí)器(分類器)中的所有權(quán)重參數(shù)進(jìn)行微調(diào),而“MAML+權(quán)重分解”則是先將基礎(chǔ)學(xué)習(xí)器(分類器)中權(quán)重分解為泛化權(quán)重與敏感權(quán)重,通過(guò)凍結(jié)泛化權(quán)重,僅微調(diào)敏感權(quán)重的方式來(lái)學(xué)習(xí)新任務(wù)?!癕AML+權(quán)重分解”相對(duì)于“MAML”,在miniImageNet數(shù)據(jù)集上,權(quán)重分解策略在1-shot與5-shot上分別提升了2.78個(gè)百分點(diǎn)與4.14個(gè)百分點(diǎn);在tieredImageNet 數(shù)據(jù)集上,權(quán)重分解策略在1-shot與5-shot上分別提升了2.99個(gè)百分點(diǎn)與1.07個(gè)百分點(diǎn)。由此可以得出,權(quán)重分解策略對(duì)于基礎(chǔ)學(xué)習(xí)器(分類器)的微調(diào)是有效的。
3.4.2 記憶模塊的效果
在網(wǎng)絡(luò)的學(xué)習(xí)過(guò)程中,記憶模塊存儲(chǔ)任務(wù)信息與敏感權(quán)重信息,在后續(xù)任務(wù)中,通過(guò)讀取記憶模塊中的這些先驗(yàn)知識(shí)來(lái)幫助網(wǎng)絡(luò)快速地進(jìn)行遷移學(xué)習(xí)。如表4所示,“MAML+權(quán)重分解+記憶模塊”是本文方法,相對(duì)于“MAML+權(quán)重分解”,1-shot與5-shot在miniImageNet數(shù)據(jù)集上分別提升4.82 個(gè)百分點(diǎn)與5.30 個(gè)百分點(diǎn),在tieredImageNet 數(shù)據(jù)集上分別提升6.26 個(gè)百分點(diǎn)與5.46個(gè)百分點(diǎn),在CUB 數(shù)據(jù)集上分別提升3.93 個(gè)百分點(diǎn)與6.41個(gè)百分點(diǎn)。由此可見(jiàn),記憶模塊通過(guò)給分類器提供一個(gè)更好的初始化狀態(tài)對(duì)模型的微調(diào)有積極作用。
在本文中,提出了一個(gè)基于記憶模塊的元學(xué)習(xí)方法致力于解決小樣本學(xué)習(xí)問(wèn)題。相對(duì)于傳統(tǒng)的元學(xué)習(xí)方法做了兩處改進(jìn)。首先,針對(duì)于基礎(chǔ)學(xué)習(xí)器在微調(diào)過(guò)程中出現(xiàn)的待微調(diào)參數(shù)過(guò)多,提出了一種權(quán)重分解策略,來(lái)將基礎(chǔ)學(xué)習(xí)器的權(quán)重分解為凍結(jié)權(quán)重與可學(xué)習(xí)權(quán)重,凍結(jié)權(quán)重用來(lái)保證模型的泛化能力,可學(xué)習(xí)權(quán)重用來(lái)學(xué)習(xí)新任務(wù),這樣的策略在小樣本學(xué)習(xí)中更為有效。其次,針對(duì)于基礎(chǔ)學(xué)習(xí)器的初始化狀態(tài)不佳,借助了一個(gè)記憶模塊來(lái)保存先前的任務(wù)與權(quán)重信息,根據(jù)當(dāng)前任務(wù)讀取記憶中的先驗(yàn)知識(shí)來(lái)更有效地初始化基礎(chǔ)學(xué)習(xí)器,以此幫助基礎(chǔ)學(xué)習(xí)器快速學(xué)習(xí)新任務(wù)。在miniImageNet、tieredImageNet以及CUB 數(shù)據(jù)集上與其他方法進(jìn)行效果對(duì)比,從實(shí)驗(yàn)結(jié)果上看,對(duì)比于較先進(jìn)的方法,本文方法在小樣本分類以及細(xì)粒度分類任務(wù)上取得了具有競(jìng)爭(zhēng)力的表現(xiàn)。
計(jì)算機(jī)工程與應(yīng)用2022年19期