鄭欣悅 黃永輝
(中國科學院國家空間科學中心復雜航天系統(tǒng)電子信息技術(shù)重點實驗室 北京 100190)(中國科學院大學 北京 100049)
近年來,人工智能技術(shù)研究飛速發(fā)展,深度學習算法已在圖像識別領(lǐng)域取得了突破性的進展,但算法也逐漸顯露出泛化能力差、所需訓練數(shù)據(jù)大等缺點。目前,以CNN為基礎(chǔ)的圖像識別方法通常需要海量的訓練數(shù)據(jù)和充足的迭代次數(shù),才可對特定的圖像類別進行精準的分類。然而,實際應(yīng)用中研究者常面臨數(shù)據(jù)稀缺的情況,比如罕見物種圖片、珍貴的醫(yī)療診斷圖片、稀有遙感圖像等,采集這些數(shù)據(jù)的難度大且成本高。而少量的樣本通常不足以訓練出一個較好的深度神經(jīng)網(wǎng)絡(luò)。因此,如何實現(xiàn)小樣本圖像識別成為了計算機視覺領(lǐng)域的重要研究方向。
針對小樣本學習問題,深度學習領(lǐng)域存在著許多不同的解決方案,其中元學習方法取得了尤為顯著的成效。元學習(Meta-learning)是指導分類器學會如何學習的過程。元學習器在有限的樣例中對結(jié)構(gòu)基礎(chǔ)層次和參數(shù)空間進行優(yōu)化,以獲得跨任務(wù)泛化性能[1],具備小樣本學習的能力。訓練完成的元學習器可以僅根據(jù)1至5個輸入-輸出樣例對新的測試樣本進行分類。
目前元學習的方法可以歸類為以下幾種:基于記憶存儲的方法[2,9]通過權(quán)重更新來調(diào)整偏差,并不斷地從記憶中學習。Santoro等[2-3]利用神經(jīng)圖靈機引入的外部存儲器來實現(xiàn)短期記憶并在標簽和輸入圖像之間建立連接,使輸入能夠與存儲器中的相關(guān)圖像進行比較,以實現(xiàn)更好的預測?;谔荻鹊姆椒╗4-5]通常通過訓練額外的網(wǎng)絡(luò)來預測分類器更新策略,如Larochelle等[5]提出訓練LSTM優(yōu)化器以學習分類器網(wǎng)絡(luò)的參數(shù)優(yōu)化規(guī)則。關(guān)系網(wǎng)絡(luò)[6]和匹配網(wǎng)絡(luò)[7]采用度量學習的思想,不再使用人工設(shè)計的指標,而是完全利用神經(jīng)網(wǎng)絡(luò)來學習深度距離度量。Finn等[8]提出了一種稱為MAML模型無關(guān)的元學習方法。該方法的基本思想是同時啟動多個任務(wù),然后獲取不同學習任務(wù)的合成梯度方向來更新神經(jīng)網(wǎng)絡(luò)。這樣的優(yōu)化方式能找到最適合網(wǎng)絡(luò)的初始化位置,這里的初始化位置被定義為:僅通過幾個小樣本的訓練可以調(diào)整到最好表現(xiàn)。Reptile[10]是OpenAI提出的簡化版MAML算法,MAML需要在反向傳播中計算二階導數(shù),而Reptile只需要計算一階導數(shù),消耗更少的計算資源且更易于實現(xiàn)。
雖然上述方法取得了令人矚目的成果,但普遍存在兩個缺陷:一是算法引入人為設(shè)計的規(guī)則來約束學習;二是需要更多額外的存儲空間對經(jīng)驗進行存儲,并且沒有提供將知識轉(zhuǎn)移到其他任務(wù)的理論手段。因此本文提出一種結(jié)合表征學習和注意力機制[18-19]的元學習方法VAE-ATTN。表征模塊利用過去的知識,將高維圖像數(shù)據(jù)表達為有意義的高級表征;注意力模塊引導學習器關(guān)注關(guān)鍵特征,以快速適應(yīng)新的學習任務(wù)。
VAE-ATTN算法首先運用變分自編碼器VAE[11-12]通過無監(jiān)督學習方法獲取各個任務(wù)內(nèi)部共享的特征。VAE的編碼器保留預訓練后的網(wǎng)絡(luò)模型參數(shù),將提取的低維高級表征遷移到不同的識別任務(wù)中。同時,在通道維度引入注意力機制,通過計算概率分布選擇性加強對當前學習任務(wù)更重要的特征。本文使用Reptile元學習算法作為基準算法。實驗結(jié)果表明,VAE-ATTN算法整體性能優(yōu)于MAML、MatchingNets、Meta-LSTM等對比算法,驗證了有效的表征學習和注意力機制的結(jié)合能獲得更加精準的小樣本分類結(jié)果。
本文使用變分自編碼器進行表征學習,表征學習的目標是從數(shù)據(jù)中自動學習到從原始數(shù)據(jù)到數(shù)據(jù)表征之間的映射。VAE作為深度神經(jīng)網(wǎng)絡(luò),由編碼器和解碼器構(gòu)成。如圖1所示,VAE本質(zhì)是提取數(shù)據(jù)的隱特征,構(gòu)建從隱特征到生成目標的模型。編碼器從原始數(shù)據(jù)中提取潛在的合理變量,再對編碼結(jié)果加上高斯噪聲加以約束,使之成為服從高斯分布的隱含特征。解碼器構(gòu)建的模型將隱特征映射到重新生成的概率分布中,重構(gòu)的分布需盡量與原始分布相同。
圖1 VAE的工作原理
網(wǎng)絡(luò)有兩個組件:具有參數(shù)φ的編碼器網(wǎng)絡(luò)E和具有參數(shù)θ的解碼器D,其損失函數(shù)為:
L(φ,θ,x)=Eqφ(z|x)[logpθ(x|z)]-DKL(qφ(z|x)‖pθ(z))
(1)
式中:qφ(z|x)表示從數(shù)據(jù)空間到隱含空間的編碼器;pθ(x|z)表示從隱含空間到數(shù)據(jù)空間的解碼器。
損失函數(shù)由兩方面構(gòu)成:式(1)第一項為重構(gòu)誤差,驅(qū)使重構(gòu)的pθ(x|z)分布更接近于輸入分布pθ(x);第二項旨在減小KL散度,驅(qū)使qφ(z|x)更接近于先驗分布pθ(z)。為了實現(xiàn)這種重構(gòu),VAE將捕捉到可以代表原始輸入數(shù)據(jù)的最重要的特征因素。
特別地,我們嘗試用VAE的變體β-VAE[13-15]進行實驗。β-VAE引入解纏性先驗[21],假設(shè)數(shù)據(jù)是基于互相獨立的因素生成的,因此可以用表征中不同的獨立變量表示這些因素。該解纏性先驗可促進編碼器學習數(shù)據(jù)簡潔的抽象表示,從而用于各種下游任務(wù)并提升樣本效率。
如式(2)所示,β-VAE引入了一個可調(diào)節(jié)的超參數(shù)β,它可控制隱變量的維度以及重建精度之間的平衡,同時高斯先驗的各向同性性質(zhì)也給學習的后驗帶來了隱形的約束。β變化會改變訓練期間學習程度,從而鼓勵不同的學習表征,實驗中需要調(diào)整的值以促進使用解纏后的表征。
L(φ,θ,x)=Eqφ(z|x)[logpθ(x|z)]-βDKL(qφ(z|x)‖pθ(z))
(2)
VAE的無監(jiān)督學習階段需要編碼器對輸入數(shù)據(jù)降維,并從中提取通用而高級的表征,以適用于小樣本學習中不同圖像類別的一系列任務(wù)分布。從元學習的角度處理這個問題,將目標定義為一個有效的學習過程,可以從無標記數(shù)據(jù)轉(zhuǎn)移到少標記樣本的任務(wù)。
Bengio等[21]提出具有適合特定任務(wù)和數(shù)據(jù)域的表征可以顯著提高訓練模型的學習成功率和穩(wěn)健性。因此,本文對VAE提取的高級表征構(gòu)建注意力機制,使元學習器能在全局信息中關(guān)注更有利于當前學習任務(wù)的目標表征。自注意機制與人類視覺注意力機制起著類似的作用,從大量的信息中篩選出部分關(guān)鍵的信息,并聚焦到這些重要的信息上。
圖2闡述了注意力模型的內(nèi)部結(jié)構(gòu)。該模塊通過分析輸入數(shù)據(jù)的總特征,捕獲通道間依賴關(guān)系,預測通道重要性,以此選擇性地強調(diào)某些特征。
圖2 注意力模型的網(wǎng)絡(luò)結(jié)構(gòu)以及相應(yīng)特征的維度
根據(jù)預訓練過編碼器產(chǎn)生的隱特征γ構(gòu)建注意力模塊的輸入,γ∈Rb×h×w×c,其中b為批大小(Batch size),h和w為特征圖的長和寬,c是通道數(shù)。由式(3)-式(6)所示,Q和K由輸入特征γ通過1×1卷積的跨通道信息整合而得的新的特征圖,并將維度變換為Rx×c,其中x=h×w,接著在Q和K的轉(zhuǎn)置之間執(zhí)行矩陣乘法,最后使用softmax函數(shù)進行歸一化,得到維度為c×c注意力概率分布αji。這樣設(shè)計的意義在于計算γ的每個通道數(shù)之間的影響力權(quán)重,可以突出關(guān)鍵特征圖的作用,減少冗余特征對整體分類性能的影響。
Q=reshape(FCNN(γ;θ1))
(3)
K=reshape(FCNN(γ;θ2))
(4)
V=reshape(γ)
(5)
(6)
最后,將權(quán)重系數(shù)αij與原始特征進行加權(quán)求和,再用尺度系數(shù)β加以調(diào)整,即可獲得辨別性高的特征表達Oj:
(7)
其中:β初始化為0,在學習的過程中逐漸分配到更大的權(quán)重。
該注意力模塊能自適應(yīng)地整合局部特征并明確全局依賴,使得元學習器能注意到更有用的特征,在樣本匱乏的情況下出色地完成分類工作。
針對傳統(tǒng)深度學習方法的局限性,VAE-ATTN提供了很好的解決方案。VAE-ATTN提出通過預訓練VAE學習任務(wù)高級表征,混合使用注意力機制快速運用關(guān)鍵表征的方法,最大化從少量樣本中獲取的有效信息。
方法分為兩個階段,第一階段為表征模塊的預訓練。算法使用深度生成模型VAE構(gòu)建一個提供數(shù)據(jù)嵌入或特征表征的模型。預訓練集由大規(guī)模圖像分類數(shù)據(jù)集ImageNet上隨機抽取的150個類組成,這些類別和元數(shù)據(jù)集中的類別沒有重疊。VAE從預訓練集中學習各個圖像類別共享的特征子集。特別地,實驗嘗試使用β-VAE作為表征模塊,相比于線性嵌入或從常規(guī)變分自編碼器獲得的特征,β-VAE能夠提取解纏的特征,具有更加有效的表征能力。
第二階段為元學習階段。將預訓練完成的VAE編碼器,作為特征提取器遷移至新的識別任務(wù)中。VAE輸出的通道響應(yīng)彼此關(guān)聯(lián),每個通道映射可以被視作特定于類別的響應(yīng)。因此對VAE的輸出特征引入注意力機制,利用通道映射之間的相互依賴性,選擇性地強調(diào)相互依賴的特征映射,并改進特定類別的特征表示。本文使用的基準元學習算法為模型無關(guān)的Reptile元學習方法,Reptile掌握任務(wù)分布規(guī)律,從特征空間和參數(shù)空間對元學習器進行聯(lián)合優(yōu)化。
圖3為基于VAE和注意力機制的元學習圖像分類架構(gòu)。編碼器是深度為4的卷積網(wǎng)絡(luò),解碼器由4層反卷積構(gòu)成。對編碼器提取的特征輸入注意力模塊,進行特征加強。最后通過由全連接層和Softmax層組成的分類器,得到圖像分類成果。這樣的結(jié)構(gòu)即保留了抽象的圖像特征,又為在面臨新任務(wù)的學習時保留了調(diào)整的余地。算法運行的偽代碼如算法1所示。
圖3 VAE-ATTN圖像分類框架
算法1VAE-ATTN元學習算法
1 預訓練VAE模型,重復步驟1)-步驟2)直至圖像重構(gòu)誤差小于σ:
1) 從預訓練集中采樣n張圖片P(0)~P(n-1);
2) 在每幅圖像上執(zhí)行隨機梯度下降,優(yōu)化網(wǎng)絡(luò)編碼器參數(shù)φ和解碼器參數(shù)θ。
2 將預訓練好的編碼器的參數(shù)值φ固定,連接Attention模塊。
3 Attention模塊參數(shù)A在元數(shù)據(jù)集上通過Reptile算法進行訓練以學會強調(diào)關(guān)鍵的特征圖,步驟1)-步驟3)預定義的J次:
1) 從元數(shù)據(jù)集中采樣n個任務(wù)τ(0)~τ(n-1);
2) 在每個任務(wù)τi上執(zhí)行連續(xù)k步梯度下降,計算權(quán)值Wi=SGD(Lτi,k,A);
4 在測試集上驗證模型,獲得最終準確率。
Reptile[10]作為基準元學習算法,本質(zhì)上是通過不斷地采樣不同類別的任務(wù),在任務(wù)層面實現(xiàn)知識的泛化。算法的優(yōu)化目標如下:
(8)
為了驗證基于VAE和注意力機制的元學習方法的有效性,實驗選取兩個重要的基準數(shù)據(jù)集Mini-ImageNet和Omniglot進行實驗,并將測試結(jié)果與其他元學習方法進行比較。Omniglot[16]是Lake等提出的語言文字數(shù)據(jù)集,該數(shù)據(jù)集包含50種文字,1 623類手寫字符,每一類字符僅擁有20個樣本,且這些樣本均為不同的人繪制而成。Mini-ImageNet[7]數(shù)據(jù)集由DeepMind于2016年提出,是計算機視覺領(lǐng)域的重要基準數(shù)據(jù)集,它通過從ImageNet隨機抽樣100個類并為每個類選擇600個樣本創(chuàng)建而成。其中:訓練集包含64個類別,共計38 400幅圖像;測試集包含20個類別,共計12 000幅圖像;驗證集包含16個類,9 600張圖像。
預訓練階段:變分自編碼器從原始的,未標記的預訓練集數(shù)據(jù)中進行學習。從ImageNet中隨機抽取150類,每類600張圖片組成預訓練集。預訓練集沒有與Mini-ImageNet數(shù)據(jù)集中的類別重疊。在β-VAE訓練階段,本文采用Adam優(yōu)化器,固定學習率為0.001。編碼器模型運用4層CNN卷積層,每層使用64個大小為3×3的卷積核,輸出為100維的隱變量。損失函數(shù)一方面通過交叉熵來度量圖片的重構(gòu)誤差,另一方面,通過KL散度來度量隱變量的分布和單位高斯分布的差異。根據(jù)損失函數(shù)的收斂特性,本文選取的批大小為32,以獲得隨機性避免陷入局部最優(yōu)化。
元學習階段:網(wǎng)絡(luò)運用訓練集中有標記的,訓練集數(shù)據(jù)樣本進行學習。在預訓練階段之后,β-VAE已經(jīng)從預訓練集中學習了低維的高級特征,元學習器只需要通過快速調(diào)整其注意力模塊來學習如何適應(yīng)新的學習任務(wù)。網(wǎng)絡(luò)使用Reptile算法對注意力模塊進行2萬次的訓練迭代,每次連續(xù)計算8步梯度下降來更新網(wǎng)絡(luò)參數(shù),詳細超參設(shè)置見表1。
表1 元學習參數(shù)表
實驗考慮解決小樣本分類中K-樣本,N-類別[7]學習問題。對于K-樣本,N-類別(K-shot,N-way)分類的每個任務(wù),學習器訓練N個相關(guān)類,每個類都有K個例子,首先從元數(shù)據(jù)集中采樣N個類,為每個類選擇K+1個樣本。然后,將這些示例拆分為訓練和測試集,其中訓練集包含每個類的K個示例,測試集包含剩余樣本。以5-樣本,5-類別分類為例,實驗中共抽取30個樣例,使用其中25個樣本5(圖像)×5(類)訓練學習器并使用剩余的示例來測試模型。
4.2.1β-VAE的重構(gòu)分析
對于無監(jiān)督學習階段,實驗考察了β參數(shù)對提取解纏特征的影響。實驗發(fā)現(xiàn)β=8是對于最終學習器進行小樣本分類的最合適的參數(shù)值,實驗中大約一半的隱變量已經(jīng)收斂到單位高斯先驗。如圖4所示,(a)為測試圖片,(b)為β=8時的β-VAE重構(gòu)圖像。從圖像重建的質(zhì)量上分析,由于隱變量的維度受限,良好的解纏表征可能會導致模糊的重建[12]。但解纏表征例如旋轉(zhuǎn)、大小、位置等有助于加速后期元學習階段的學習,幫助注意力模塊理解不同任務(wù)之間的共享特征,對提升小樣本分類性能有更明顯的成效。
(a) 測試圖像
(b) β-VAE的輸出(β=8)圖4 測試圖像與重構(gòu)圖像
4.2.2注意力影響可視化分析
該部分實驗成果可視化了注意力機制給小樣本分類帶來的影響。實驗使用t-SNE算法[20]將網(wǎng)絡(luò)輸出的特征值降維并投影至2維空間。圖5是Mini-ImageNet實驗中測試場景的特征可視化圖,(a)為特征在進入注意力模塊之前的前期特征,(b)為經(jīng)過注意力機制增強之后的特征。為使圖像表述更加清晰,t-SNE實驗中共采樣3種類別,每種類別200幅圖像進行降維,圖中的3種標記符號分別代表3個不同的類別。
(a) (b)圖5 特征通過t-SNE投影至2維空間的可視化結(jié)果
可以看出,在經(jīng)過注意力模塊的特征改進之后,不同圖像類別之間的分布差異更加明顯,類內(nèi)距離的標準差縮小,而類間距標準差增大。實驗結(jié)果表明,注意力機制可以捕獲高級特征里的關(guān)鍵特征,有助于元學習器更好地區(qū)分不同類別的圖像。
4.2.3小樣本圖像分類結(jié)果
將VAE-ATTN元學習方法與現(xiàn)有元學習方法相比較,表2及表3展示了基礎(chǔ)設(shè)置和直推設(shè)置的實驗成果。在直推模式中,元學習器允許同時擁有標簽訓練樣本和無標簽測試樣本,訓練后的模型一次性對測試集中的所有樣本進行分類,因此允許信息通過批量標準化在測試樣本之間共享[9]。也就是說,測試樣本的類標簽預測過程會受到彼此的影響,不再是相互獨立的。表2與表3中,Y表示運用了直推設(shè)置,N表示未運用直推設(shè)置。觀察實驗結(jié)果發(fā)現(xiàn),使用直推設(shè)置的分類結(jié)果明顯優(yōu)于未使用該設(shè)置的結(jié)果。
表2 Mini-ImageNet 小樣本分類結(jié)果 %
表3 Omniglot小樣本分類結(jié)果 %
續(xù)表3 %
從表2中可以看出,在Mini-ImageNet上,本文提出的算法超過了當前性能優(yōu)異的元學習算法,如MAML、MatchingNets、Meta-LSTM等。在5-樣本,5-類別以及1-樣本,5-類別的測試場景中分別獲得72.5%和53.5%的準確率,顯著超越原始Reptile算法的分類性能。由表3可知,在Omniglot數(shù)據(jù)集上,β-VAE在5-樣本,20-類別以及1-樣本,20-類別的測試場景中,取得了98.8%和96.5%的高分類準確率。實驗結(jié)果說明基于表征學習和注意力機制的方法改善了元學習器,證明了VAE-ATTN算法的合理性。
圖6是Mini-ImageNet中5-樣本,5-類別的直推實驗的分類準確率曲線圖??梢钥闯?,VAE-ATTN算法均超出Reptile基準元學習算法,且運用β-VAE進行預訓練的分類效果也優(yōu)于常規(guī)VAE的訓練效果。這一結(jié)果說明β-VAE提取的解纏表征加速元學習器結(jié)構(gòu)化地理解多樣的任務(wù),實現(xiàn)更高的小樣本分類準確率。
圖6 Mini-ImageNet實驗分類準確率對比
小樣本圖像識別在人工智能領(lǐng)域是復雜且具有挑戰(zhàn)性的研究方向,極具探索價值和意義。本文通過分析以往元學習方法存在的問題,提出結(jié)合表征學習和注意力機制的新元學習方法VAE-ATTN。算法運用β-VAE學習的高級的解纏表征,并通過注意力機制增強重要的信息并抑制冗余的信息,從而引導元學習器進行小樣本學習。本文算法在Mini-ImageNet和Omniglot數(shù)據(jù)集上的小樣本學習測試中均展現(xiàn)了良好的性能,表明了算法的有效性和可行性。
在后續(xù)工作中,我們將考慮更具泛化性的元學習方法,目標是提取可跨任務(wù)或遠距離遷移的特征,使得小樣本學習能根據(jù)更充分的先驗知識進行新任務(wù)的快速學習。