陳 晨 王亞立 喬 宇
1(中國科學院深圳先進技術研究院 深圳 518055)
2(中國科學院大學深圳先進技術學院 深圳 518055)
近年來,深度學習(Deep Learning)[1]在計算機視覺領域的發(fā)展不斷獲得突破和成功,如其不斷刷新圖像識別、目標檢測等領域的最優(yōu)結果。深度學習是一項數(shù)據(jù)驅動的技術,其性能嚴重依賴標注數(shù)據(jù)的數(shù)量。然而,大量標注數(shù)據(jù)的收集存在多種挑戰(zhàn)。一方面,在諸如醫(yī)療、安全等特定領域,由于涉及隱私或國家安全等問題,數(shù)據(jù)的采集受到嚴格限制;另一方面,大規(guī)模數(shù)據(jù)的采集、清洗和標注需要耗費大量的人力和物力。因此,如何使用少量樣本訓練深度網(wǎng)絡,成為亟待解決的問題。受到人類從少量數(shù)據(jù)中快速學習能力的啟發(fā),Li 等[2]提出小樣本學習(Few-Shot Learning)的概念。其目的是使在已知類別(Seen Class)中訓練的分類模型,面對只有少量標注數(shù)據(jù)的未知類別(Unseen Class)依然具有較好的性能。
目前,小樣本學習已成為深度學習領域中非常重要的前沿研究問題,在醫(yī)療圖像分析等數(shù)據(jù)采集難度較大的領域具有十分廣闊的應用前景。如何在圖像小樣本的數(shù)據(jù)上訓練得到一個泛化性能較好的模型,是一個既有理論意義又具有實際應用價值的研究課題。國內外研究學者提出的常見解決方法主要有兩種:一種是基于元學習(Meta Learning)[3-6]的方法,另一種是基于度量學習(Metric Learning)[7-13]的方法。
元學習是機器學習的一個子領域,其目標為使模型學會學習。例如,Santoro 等[3]提出使用記憶增強的方法來解決小樣本識別問題。該方法利用權重更新來調節(jié)偏差,使模型學會通過將表達快速緩存到記憶中來調節(jié)輸出。由于傳統(tǒng)的梯度下降方法(如 Adagrad[14]、Adadelta[15]、Adam[16]等)需要選取眾多超參數(shù),無法在幾步內完成優(yōu)化,F(xiàn)inn 等[4]提出 MAML(Model-Agnostic Meta-Learning),通過找到一個模型參數(shù)的更加敏感狀態(tài),使模型能快速地遷移到新的任務上。Rusu 等[5]提出 LEO(Learning Embedding Optimization)方法,通過構造隱層空間解決 MAML 不能很好處理高維數(shù)據(jù)的問題。Ravi 等[6]提出一個基于長短期記憶網(wǎng)絡(Long Shot Term Memory,LSTM)[17]的元學習器(Meta Learner)模型,利用 LSTM 來代替梯度下降算法的更新規(guī)則。該方法通過學習一個通用的初始化方式,使得模型在新任務上可以從一個好的初始狀態(tài)開始訓練。以上方法的共同點是通過在訓練集上學習到的元知識,幫助模型很好地泛化到新的任務上。
度量學習則是將分類問題轉化為樣本間的相似性度量問題。其主要思路是將圖像映射到更具有區(qū)分性的特征空間中,然后通過比較待分類樣本和已標注樣本在特征空間中距離的遠近,來預測待分類樣本的類別。一個更有區(qū)分性的特征空間應該具備這樣的性質[8]:同類之間的圖像特征嵌入距離較近,而不同類之間的圖像特征嵌入距離較遠。其中,距離的度量方式包括歐氏距離和余弦距離等。
度量學習的目標是學習一個具有較好泛化能力的圖像到特征空間的映射。例如,Vinyals等[7]提出匹配網(wǎng)絡(Matching Networks)和任務片段式訓練(Task Episode Training)方法,同時使用余弦距離對樣本進行分類。Snell 等[8]提出原型網(wǎng)絡(Prototypical Networks),用原型作為一個類別在特征空間中的表示,并使用歐式距離來進行分類。Sung 等[9]提出一個可學習的度量表示——關系網(wǎng)絡(Relation Network),與簡單的歐式距離或余弦距離相比,它能夠更好地表示樣本之間的相似關系。Liu 等[10]提出傳導傳播網(wǎng)絡(Transductive Propagation Network,TPN),通過圖構造模塊來表征新類別數(shù)據(jù)的流形結構,學習如何將標簽從已標記的支持集樣本傳播到未標記的查詢集樣本。Gidaris 等[11]提出去噪自編碼器圖神經(jīng)網(wǎng)絡(Denoising Autoencoders Graph Neural Network,wDAE-GNN),基于已知類別的分類參數(shù)的數(shù)據(jù)分布,利用降噪編碼器同時重建已知類別的分類參數(shù)和小樣本未知類別的參數(shù)分布。除任務片段式訓練方法外,也可以先用整個訓練集做預訓練,再遷移到小樣本類別上。例如,Qiao 等[12]提出 PFA(Predicting Parameters from Activations),使用激活函數(shù)輸出層的倒數(shù)生成小樣本類別對應的參數(shù)。Gidaris 等[13]提出 DFVL(Dynamic Few-shot Visual Learning),通過小樣本權重生成器生成對應類別的參數(shù)。其中,小樣本權重生成器由訓練集類別權重和余弦相似度相乘得到。值得一提的是,這些基于度量學習的方法都是任務無關的。由于訓練樣本量過少,任務無關帶來的后果是模型容易過擬合已知類別,而對新類別上的查詢任務泛化能力不足[7-8]。
本文在傳統(tǒng)度量學習方法的基礎上,提出任務相關的特征嵌入模塊來抑制過擬合,引導模型充分地利用任務的信息,可根據(jù)查詢任務自適應地調整支持集樣本的特征嵌入,使得從已知類別上學習的從圖像到特征的映射在未知類別上也具有很好的泛化性。其中,任務相關的特征嵌入模塊沒有引入龐大的參數(shù)或復雜的計算,很好地控制了模型的復雜度,避免過擬合問題。同時該模塊具有很強的擴展性,可以在大部分基于度量學習的方法中方便地引入,以提高特征嵌入的可區(qū)分性。同時,本文還引入了多種正則化方法,解決數(shù)據(jù)量較少帶來的過擬合問題,提高小樣本圖像分類的性能。最終通過對不同方法的結果進行對比和分析,驗證了本文所提出方法的有效性。
小樣本學習存在兩個重要的問題:(1)已知類別和未知類別之間沒有交集,導致它們的數(shù)據(jù)分布差別很大,不能直接通過訓練分類器和微調的方式得到很好的性能;(2)未知類別只有極少量數(shù)據(jù)(每個類別僅 1 個或 5 個訓練樣本),導致分類器學習不可靠。根據(jù)小樣本圖像分類的特點,本文的小樣本圖像分類采用了任務片段式訓練與測試。具體而言,分為元訓練和元測試兩個階段。
度量學習的核心問題是尋找一種更優(yōu)的映射,將圖像嵌入到一個更有區(qū)分性的空間中。本研究希望這種映射不僅適用于已知類別,還要適用于未知類別。傳統(tǒng)基于度量的小樣本學習方法是任務無關的——支持集樣本的特征嵌入只與該樣本自身有關,而與查詢任務無關。在已知類別上訓練的模型,面對未知類別的查詢任務時,無法自適應地調整支持集樣本的特征嵌入方式。因此,模型的泛化性能會受到很大影響。
如圖 1 所示,傳統(tǒng)的任務無關的度量學習方法得到了支持集樣本和查詢集樣本的任務無關的特征嵌入。在該特征空間中,查詢集樣本和各支持集樣本間的距離很接近,因此沒有很好的區(qū)分性。而在本文中,所提出的任務相關的特征嵌入模塊,通過將支持集樣本特征嵌入和查詢集樣本特征嵌入的拼接,隨后接入卷積層,網(wǎng)絡通過學習可以根據(jù)不同的查詢集樣本自適應地改變支持集樣本的特征嵌入方式,使其面對未知類別的待分類樣本也具有較好的泛化性能。
2.2.1 Mixup
圖 1 任務相關的特征嵌入模塊Fig. 1 Task-relevant feature embedding module
Mixup[18]使用兩個不同樣本和對應標簽的凸組合來訓練深度神經(jīng)網(wǎng)絡。在傳統(tǒng)的圖像分類任務中,使用 Mixup 可以有效控制模型的復雜度,提高網(wǎng)絡泛化能力。本文將其引入小樣本分類任務中,以解決數(shù)據(jù)量嚴重缺乏帶來的過擬合問題。與傳統(tǒng)的圖像分類不同的是,在基于度量學習的小樣本分類中,查詢樣本是待分類的對象,而支持集樣本是度量分類器的組成部分。因此,Mixup 的引入需要對查詢樣本的標簽及度量分類器都做出適當?shù)恼{整。
訓練時,每次采樣兩組類別相同的 C-way K-shot 任務:
其中,S 為支持集;s 為支持集樣本圖片數(shù)據(jù);l 為 s 對應的獨熱編碼標簽;Q 為查詢集;q 為查詢集樣本圖片數(shù)據(jù);y 為 q 對應的獨熱編碼標簽。
將兩個任務的樣本打亂后進行 Mixup,可以得到新的任務 T,其中的樣本和標簽定義如下:
2.2.2 標簽平滑
本文選擇在兩個小樣本學習領域中被廣泛使用的標準數(shù)據(jù)集——miniImageNet[7]和tieredImageNet[20]上進行實驗。miniImageNet 數(shù)據(jù)集是 ImageNet[21]數(shù)據(jù)集的一個子集,包含 100 類圖片數(shù)據(jù),每類包含 600 張圖片。其中,訓練集、驗證集和測試集分別包含 64 類、16 類和 20 類。tieredImageNet 也是 ImageNet 數(shù)據(jù)集的一個子集,包含 608 類,共 779 165 張圖片。其中,訓練集、驗證集和測試集分別包含 351 類、97 類和 160 類。由此可見,tieredImageNet 的訓練集、驗證集和測試集劃分更加謹慎,從而確保每個集合中的類別差異更大。
近年來出現(xiàn)了許多基于度量學習的小樣本圖像分類方法,本文選取其中一個十分具有代表性的方法——原型網(wǎng)絡[8]作為基準方法。該方法首先使用特征提取網(wǎng)絡提取支持集樣本和查詢集樣本的特征嵌入;然后,根據(jù)支持集樣本的類別對支持集樣本的特征向量求出均值向量,以此作為該類別的特征嵌入,接著計算查詢集樣本的特征嵌入與各類別的特征嵌入的歐式距離,并使用距離的相反數(shù)作為預測的類別分數(shù);最后,在訓練階段使用帶有指數(shù)歸一化函數(shù)(Softmax)的交叉熵損失函數(shù)作為目標函數(shù)對模型進行優(yōu)化,在測試階段通過選取分數(shù)最大的類別進行預測。
在此基礎上,本文引入了類別相關的特征嵌入模塊。在使用特征提取網(wǎng)絡提取支持集樣本和查詢集樣本的特征嵌入后,將支持集樣本的特征嵌入和查詢集樣本的特征嵌入拼接起來,輸入卷積層得到新的特征嵌入。其余部分的處理與基準方法一致。
為了更加公平地與當前最好的模型進行比較,本文使用了 3 種不同的特征提取網(wǎng)絡——ConvNet[8]、ResNet[22]和 WideResNet[23]。
(1)ConvNet:ConvNet 由四層卷積模塊組成,其中每個卷積模塊由卷積層、批歸一化層(Batch Normalization Layer)、Leakly ReLU 層和最大池化層(Max Pooling Layer)順序連接組成。
(2)ResNet:深度殘差網(wǎng)絡(Deep Residual Networks)提出殘差學習的概念,通過在層與層之間引入一個恒等連接,將上一層的輸入與后面層的輸出直接相加,很好地解決了卷積神經(jīng)網(wǎng)絡(CNN)隨著層數(shù)加深而出現(xiàn)的性能退化問題。
(3)WideResNet:WideResNet(WRN)通過實驗發(fā)現(xiàn)增加網(wǎng)絡寬度(網(wǎng)絡中每一層的卷積核個數(shù))并減少網(wǎng)絡深度(網(wǎng)絡層數(shù)),可以有效地提升模型性能并提升訓練速度。
首先,使用訓練集的全部數(shù)據(jù)對特征提取網(wǎng)絡進行預訓練。然后,使用預訓練模型作為初始化參數(shù),進行元訓練。其中,使用 Adam 優(yōu)化器,初始學習率為 10-3,每 15 000 次迭代學習率減半,權重衰減為 10-6,標簽平滑的超參數(shù) α=0.2。本文在訓練中采用了隨機改變大小、隨機裁剪和隨機水平翻轉、顏色抖動(明度、對比度、飽和度和色相等變化)等數(shù)據(jù)增強方式。最后,在元測試階段,采樣 600 組 C-way K-shot 的測試任務,每組中查詢集包含 15 個樣本。并通過上述 600 組任務的準確率來計算準確率均值和 95% 置信區(qū)間。實驗代碼使用了深度學習框架 Pytorch[24],并在單張 NVIDIA GeForce GTX Titan X GPU 上運行。
本文在 miniImageNet 和 tieredImageNet 兩個數(shù)據(jù)集上完成了 5-way 1-shot 和 5-way 5-shot 兩種任務的實驗,并將結果和目前最好的方法進行了充分對比,在測試集上的分類準確率結果如表 1 和表 2 所示。由于小樣本分類的結果在相同的特征提取網(wǎng)絡下才具有可比性,下面將針對各特征提取網(wǎng)絡下的實驗結果分別進行分析。
在使用 ConvNet 作為特征提取網(wǎng)絡時,本文提出的方法在 miniImageNet 數(shù)據(jù)集上 1-shot 的分類準確率為 55.63%,5-shot 的為 71.87%;在 tieredImageNet 數(shù)據(jù)集上 1-shot 的結果為 62.32%,5-shot 的為 78.45%。與采用相同特征提取網(wǎng)絡的 PFA 相比,在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果提高了 1.10%,5-shot 的提高了 4.00%。PFA 先用整個訓練集做預訓練,然后固定前面特征提取層參數(shù),通過僅訓練最后的分類器層的參數(shù)來防止過擬合現(xiàn)象的發(fā)生。PFA 在訓練中使用了 miniImageNet 80 類的數(shù)據(jù),包括訓練集 64 類、驗證集 16 類。而本文方法僅使用 64 類訓練集的數(shù)據(jù),就明顯超過了 PFA,表明任務相關的特征嵌入模塊以及多種正則化方法有利于減輕小樣本數(shù)據(jù)下網(wǎng)絡的過擬合現(xiàn)象。
與同樣特征提取網(wǎng)絡的 TPN 方法相比,本文方法在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果提高 1.88%,5-shot 的提高 2.44%;在 tieredImageNet 數(shù)據(jù)集上 1-shot 的提高 4.79%,5-shot 的提高 5.60%。TPN 將全部無標簽數(shù)據(jù)和有標簽數(shù)據(jù)一起建立無向圖連接,通過標簽傳播的方式得到無標簽數(shù)據(jù)的標簽。本文方法在類別差異更大的 tieredImageNet 上提升效果更加明顯,說明本文方法對新類別泛化性能更好。
在使用 ResNet-12 作為特征提取網(wǎng)絡時,本文所提出方法在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果為 63.39%,5-shot 的為 77.88%;在 tieredImageNet 數(shù)據(jù)集上 1-shot 的結果為 67.81%,5-shot 的為 83.26%。與同樣特征提取網(wǎng)絡的 DFVL 相比,在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果提高 7.94%,5-shot 的提高 7.75%。DFVL 和 PFA 類似,兩種方法都是先訓練特征提取器,再固定特征提取器訓練分類器。相比于本文端到端的訓練過程,這種分兩階段的訓練不利于效果的提升。
與同樣特征提取網(wǎng)絡的元優(yōu)化網(wǎng)絡支持向量機(MetaOptNet-SVM)[25]相比,本文在miniImageNet 數(shù)據(jù)集上的 1-shot 任務結果提高 0.75%;在 tieredImageNet 數(shù)據(jù)集上 1-shot 的結果提高 1.82%,5-shot 的提高 1.70%。MetaOptNet-SVM 采用梯度下降聯(lián)合特征提取一起聯(lián)合訓練,并把最后基于距離的分類器改進成線性分類器。本文提出的任務相關的特征嵌入模塊實現(xiàn)簡單,在 miniImageNet(1-shot)和 tieredImageNet(1-shot、5-shot)上具有更好的性能。
在使用 WRN-28-10 作為特征提取網(wǎng)絡時,本文方法在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果為 66.05%,5-shot 的為 81.72%。與同樣特征提取網(wǎng)絡的 LEO 相比,在 miniImageNet 數(shù)據(jù)集上 1-shot 的結果提高 4.29%,5-shot 的提高 4.13%。LEO 通過構造隱層空間把圖像編碼得到隱空間向量再解碼得到分類參數(shù),并把分類參數(shù)和隱空間向量使用 MAML 的方式進行訓練。這種訓練方法存在一定的不可控性,如梯度下降的步數(shù)很難選取,其中步數(shù)過多容易過擬合,步數(shù)過少則效果不夠好。本文使用基于度量的方法,在得到任務相關的特征嵌入后,通過度量可以很簡單地使用距離表達樣本的相似性,最終得到的分類器也具有足夠強的自適應能力。
本文方法在 tieredImageNet 上 1-shot 的結果為 68.96%,5-shot 的為 84.17%。與同樣特征提取網(wǎng)絡的 wDAE-GNN 方法相比,在 miniImageNet 數(shù)據(jù)集 1-shot 上的結果提高 0.78%,5-shot 上提高 1.08%。wDAE-GNN 的缺點在于,降噪編碼器需要同時重建已知類別的參數(shù)分布和小樣本未知類別的參數(shù)分布,而模型并未考慮到已知類別和未知類別的任務差異,因此在重建過程中小樣本未知類別的參數(shù)分布和真實的參數(shù)分布會產生偏差。從實驗結果對比可以發(fā)現(xiàn),本文提出的任務相關的特征嵌入模塊可以有效解決上述問題,從而顯著提高模型性能。
綜上,在特征提取網(wǎng)絡相同的前提下,本文方法在 miniImageNet 和 tieredImageNet 上的結果都超過目前的其他方法。這驗證了本文所提出的任務相關的特征嵌入模塊以及多種正則化方法在提升模型的泛化性能方面的有效性。
為了進一步探究任務相關的特征嵌入模塊與各種正則化方法本身的效果,本文還使用 ResNet-12 作為特征提取網(wǎng)絡,在 miniImageNet 數(shù)據(jù)集上進行充分的消融實驗,結果如表 3 和表 4 所示。
從表 3 可以看出,不包含任何本文所提出
方法的基準實驗 1-shot 和 5-shot 的結果分別為 55.27%、71.45%,明顯較低,說明在這種情況下模型存在較為嚴重的過擬合問題。數(shù)據(jù)增強模塊的引入給 1-shot 和 5-shot 任務分別帶來了 4.55% 和 3.78% 的性能提升。這表明,本文針對小樣本學習訓練數(shù)據(jù)不足的問題所采用的隨機改變大小、隨機裁剪和隨機水平翻轉、顏色抖動等數(shù)據(jù)增強方式,有效擴大了訓練數(shù)據(jù)的規(guī)模,同時增加了訓練樣本的多樣性,提高模型網(wǎng)絡對同一類樣本不停變換的適應性,從而使模型學習到更加本質的特征??梢姡瑪?shù)據(jù)增強從數(shù)據(jù)層面解決過擬合問題,提高模型的泛化能力。標簽平滑的策略將 1-shot 和 5-shot 的結果分別提升了 0.92% 和 0.63%??梢钥闯?,相比于獨熱編碼,平滑的標簽可以有效提高模型的泛化能力。Mixup 訓練方式的引入進一步將準確率提高 1.34% 和 1.05%,證明使用訓練樣本的線性插值進行訓練,可以約束模型的復雜度,減輕數(shù)據(jù)稀少所帶來的過擬合問題。任務相關的特征嵌入模塊將結果提升了 1.31% 和 0.97%,最終使 1-shot 和 5-shot 的準確率達到 63.39% 和 77.88%。這表明根據(jù)查詢任務主動調整支持集樣本的特征嵌入,可以幫助模型使用在已知類別上學到的元知識,快速地遷移到新的任務中。
表 1 MiniImageNet 數(shù)據(jù)集上的結果對比Table 1 Comparison with SOTA (state of the art) on miniImageNet dataset
表 2 TieredImageNet 數(shù)據(jù)集上的結果對比Table 2 Comparison with SOTA on tieredImageNet dataset
為了更進一步探究每個部分單獨剝離后對網(wǎng)絡性能的影響,本文進行了充分的對比實驗,結果如表 4 所示。首先,探究了數(shù)據(jù)增強、標簽平滑以及 Mixup 三種不同的正則化方式對網(wǎng)絡性能的影響。從表 3 可以看出,相比于不包含任何本文所提出方法的基準實驗,數(shù)據(jù)增強模塊給 1-shot 和 5-shot 任務分別帶來了 4.55% 和 3.78% 的性能提升。而在表 4 的實驗結果中,單獨剝離數(shù)據(jù)增強模塊在 1-shot 和 5-shot 任務上性能分別降低 1.25% 和 1.03%。顯然,單獨剝離數(shù)據(jù)增強模塊造成的性能損失程度并沒有表 3 實驗中在基準模型上引入該模塊所帶來的增益程度高。這說明,標簽平滑和 Mixup 兩個模塊起到了明顯的正則化作用,彌補了剝離數(shù)據(jù)增強模塊所帶來的負面影響。類似地,單獨剝離標簽平滑模塊會造成模型在 1-shot 和 5-shot 任務上性能分別降低 0.66% 和 0.33%,單獨剝離 Mixup 模塊會造成模型在 1-shot 和 5-shot 任務上性能分別降低 0.97% 和 0.74%。可以發(fā)現(xiàn),當缺失某一種正則化方式時,網(wǎng)絡性能并沒有表 3 中增加模塊引起的變化那樣劇烈,說明這些正則化方式之間存在一定的互補性。其次,進行了單獨剝離任務相關的特征嵌入模塊的實驗。正如對表 3 中結果的分析,去除任務相關的特征嵌入模塊使得模型在 1-shot 和 5-shot 任務上的結果降低了 1.31% 和 0.97%。這表明任務相關的特征嵌入模塊對于網(wǎng)絡的泛化性能至關重要。它可以有效地引導模型在已知類別上學習到有用的元知識,使其在新的任務中可以快速利用元知識對支持集樣本特征進行調整。
表 3 依次引入各模塊對性能的影響Table 3 Effect on performance after introduce each module in sequence
表 4 剝離各模塊對性能的影響Table 4 Effect on performance after remove each module
針對現(xiàn)有基于度量的圖像小樣本深度學習方法與任務無關,容易造成模型過擬合已知類別,而對新類別上的查詢任務泛化能力不足等問題。本文提出一種新穎的任務相關的小樣本深度學習方法,幫助模型根據(jù)查詢任務,自適應地調整支持集樣本的特征,從而有效形成任務相關的度量分類器。同時,本文引入多種正則化方法,進一步地提升了模型的泛化性能。實驗結果表明,這些方法可以有效地解決網(wǎng)絡的過擬合問題,提升小樣本圖像分類的準確率。在實際的數(shù)據(jù)樣本中,除了少量已標注好的樣本外,還有很多未標注的樣本。這是因為在實際的醫(yī)療、安全等領域中,完全標注需要很大的人力成本。因此,小樣本半監(jiān)督分類具有很強的應用價值和實際價值,未來可以將本文提出的方法擴展到半監(jiān)督的小樣本分類問題中。