吳鵬翔,李凡長
(蘇州大學(xué)計(jì)算機(jī)科學(xué)與技術(shù)學(xué)院,江蘇蘇州 215006)
隨著計(jì)算設(shè)備并行計(jì)算性能的大幅提升,以及近年來深度神經(jīng)網(wǎng)絡(luò)在各個(gè)領(lǐng)域不斷取得重大突破,由深度神經(jīng)網(wǎng)絡(luò)模型衍生而來的多個(gè)機(jī)器學(xué)習(xí)新領(lǐng)域逐漸成型,如強(qiáng)化學(xué)習(xí)、深度監(jiān)督學(xué)習(xí)等[1-2]。在大量訓(xùn)練數(shù)據(jù)的加持下,深度神經(jīng)網(wǎng)絡(luò)技術(shù)已經(jīng)在機(jī)器翻譯、機(jī)器人控制、大數(shù)據(jù)分析、智能推送、模式識(shí)別等方面得到了廣泛應(yīng)用[3-4]。深度學(xué)習(xí)在完成這些任務(wù)時(shí)需要在大量數(shù)據(jù)上進(jìn)行訓(xùn)練才能擬合出一個(gè)好的結(jié)果,一旦需要被識(shí)別物體類別不在訓(xùn)練集中,便無法進(jìn)行正確識(shí)別。但是在實(shí)際的許多任務(wù)中,要求在少量數(shù)據(jù)上進(jìn)行快速學(xué)習(xí)和適應(yīng)[5]。
元學(xué)習(xí)的提出為上述問題提供了一個(gè)解決方案,其目的是解決傳統(tǒng)神經(jīng)網(wǎng)絡(luò)模型泛化能力不足、對(duì)新種類任務(wù)適應(yīng)性較差的問題。快速學(xué)習(xí)的能力是人類區(qū)別于人工智能的一個(gè)關(guān)鍵特征[6],人類能夠有效地利用以前的知識(shí)和經(jīng)驗(yàn)來快速學(xué)習(xí)新的技能。元學(xué)習(xí)的訓(xùn)練和測(cè)試可類比為人類在掌握一些基本技能后快速學(xué)習(xí)并適應(yīng)新的任務(wù)[7]。例如:人類可以根據(jù)一張從未見過的動(dòng)物的照片辨認(rèn)出該動(dòng)物,而不是需要大量該動(dòng)物的照片才能辨認(rèn)。人類在幼兒階段掌握的對(duì)世界的大量基礎(chǔ)知識(shí)和對(duì)行為模式的認(rèn)知基礎(chǔ)便對(duì)應(yīng)元學(xué)習(xí)中的“元”概念[8-9]。元學(xué)習(xí)的最終目標(biāo)是實(shí)現(xiàn)擁有類似人類學(xué)習(xí)能力的強(qiáng)人工智能,這在當(dāng)前階段體現(xiàn)為對(duì)新數(shù)據(jù)集的快速適應(yīng)以得到較高的準(zhǔn)確度,因此,目前元學(xué)習(xí)目標(biāo)主要表現(xiàn)為提高泛化性能、獲取好的初始參數(shù),以及通過少量計(jì)算和新訓(xùn)練數(shù)據(jù)即可在模型上實(shí)現(xiàn)和海量訓(xùn)練數(shù)據(jù)一樣的識(shí)別準(zhǔn)確度[10]。受當(dāng)前計(jì)算資源與算法能力限制,元學(xué)習(xí)往往以小樣本學(xué)習(xí)以及對(duì)新任務(wù)的快速適應(yīng)作為切入點(diǎn),因此,當(dāng)前研究也多以在小樣本數(shù)據(jù)集上的識(shí)別準(zhǔn)確率作為實(shí)驗(yàn)衡量標(biāo)準(zhǔn)[11]。
基于度量的元學(xué)習(xí)方法是一種可行的元學(xué)習(xí)方法。KOCH等于2015年提出了一種用于解決單樣本學(xué)習(xí)圖像分類問題的方法:孿生網(wǎng)絡(luò)(Siamese network)[12],通過訓(xùn)練集學(xué)習(xí)一個(gè)卷積孿生網(wǎng)絡(luò),利用該網(wǎng)絡(luò)計(jì)算待測(cè)試圖像與所有單標(biāo)注樣本的相似度,相似度最高的單標(biāo)注樣本所對(duì)應(yīng)的類別即是待測(cè)試圖像的類別。VINYALS 于2016 年提出了匹配網(wǎng)絡(luò)模型[13],其主要?jiǎng)?chuàng)新體現(xiàn)在建模過程和訓(xùn)練過程。對(duì)于建模過程的創(chuàng)新,該文通過設(shè)計(jì)基于記憶和注意力機(jī)制的匹配網(wǎng)絡(luò),使得模型能夠?qū)⑴c訓(xùn)練的樣本進(jìn)行快速學(xué)習(xí)。對(duì)于訓(xùn)練過程的創(chuàng)新,該文基于傳統(tǒng)機(jī)器學(xué)習(xí)的一個(gè)重要原則,即訓(xùn)練和測(cè)試應(yīng)在同樣條件下進(jìn)行,提出在訓(xùn)練時(shí)每次僅使用每一類任務(wù)的少量樣本參與網(wǎng)絡(luò)的訓(xùn)練,與測(cè)試過程保持一致。SNELL 于2017 年提出了原型網(wǎng)絡(luò)[14],該網(wǎng)絡(luò)模型基于一個(gè)基本假設(shè),即在數(shù)據(jù)集中,對(duì)于每種不同的類型都存在一個(gè)原型點(diǎn)。數(shù)據(jù)集中距離該原型點(diǎn)越近的樣本,其標(biāo)簽與該原型點(diǎn)對(duì)應(yīng)的標(biāo)簽相同的概率就越大。文獻(xiàn)[15]提出了由嵌入模塊和關(guān)系模塊組成的關(guān)系網(wǎng)絡(luò),其中嵌入模塊用于提取輸入圖像的特征,關(guān)系模塊用于得到輸入特征的相似度。
傳統(tǒng)基于度量的元學(xué)習(xí)算法采用卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)提取特征,但是元學(xué)習(xí)問題中的某些樣本圖片特征不僅具有平移對(duì)稱性[16],而且還具有旋轉(zhuǎn)對(duì)稱性和鏡像對(duì)稱性[17],但是CNN 只具有平移不變性,不存在對(duì)后兩者的不變性,這就使得傳統(tǒng)的元學(xué)習(xí)算法不能利用具有對(duì)稱性的特征。常用的解決方法是數(shù)據(jù)增強(qiáng)[18],即對(duì)樣本進(jìn)行隨機(jī)變換。此類方法雖然在一定程度上增強(qiáng)了泛化性,但是并不能保留局部對(duì)稱性[19],更不能保證在每一層卷積上的等變性。群等變卷積神經(jīng)網(wǎng)絡(luò)(Group equivariant CNN,G-CNN)則能較好地解決這一問題[20],其不僅具有平移不變性,而且還具有旋轉(zhuǎn)和鏡像不變性。
為有效利用樣本圖片的局部旋轉(zhuǎn)對(duì)稱性和鏡像對(duì)稱性,提高特征提取能力,本文提出一種基于G-CNN 的度量元學(xué)習(xí)算法。通過由群等變卷積構(gòu)成的4 層映射網(wǎng)絡(luò)學(xué)習(xí)一個(gè)合適的度量空間,根據(jù)查詢集中樣本離原型點(diǎn)的距離完成分類。
元學(xué)習(xí)的目標(biāo)是跨任務(wù)的泛化??紤]一個(gè)任務(wù)分布P(T),即該模型所適配的數(shù)據(jù)的全體,目的是使這個(gè)模型可以適應(yīng)這個(gè)任務(wù)分布P(T)。與傳統(tǒng)機(jī)器學(xué)習(xí)不同,元學(xué)習(xí)不是根據(jù)每個(gè)樣本來優(yōu)化,而是根據(jù)元任務(wù)來優(yōu)化。每個(gè)元任務(wù)包含一個(gè)支持集和一個(gè)對(duì)應(yīng)的查詢集。在n-wayk-shot 元學(xué)習(xí)問題中,對(duì)于每個(gè)元任務(wù)定義支持集S和查詢集Q,支持集和查詢集中包含n個(gè)類別的樣本,支持集中每類樣本只存在k個(gè),查詢集中每類樣本個(gè)數(shù)不定,支持集S定義如式(1)所示:
其中:xi表示樣本的D維向量表示;yi表示樣本對(duì)應(yīng)的標(biāo)簽;n表示樣本類別總數(shù)。查詢集Q取自數(shù)據(jù)集中和支持集S同類別但不同的樣本,不帶標(biāo)簽。圖1給出了5-way 1-shot 元學(xué)習(xí)問題中訓(xùn)練時(shí)所使用的的支持集和查詢集示例。
圖1 5-way 1-shot 元學(xué)習(xí)問題中元訓(xùn)練使用的支持集和查詢集示例Fig.1 Example of support set and query set using in meta-training for 5-way 1-shot meta-learning problems
在訓(xùn)練階段,從P(T)的訓(xùn)練數(shù)據(jù)集上采樣訓(xùn)練元任務(wù),通過元任務(wù)對(duì)損失函數(shù)進(jìn)行最小化,從而優(yōu)化模型參數(shù)。在訓(xùn)練結(jié)束后,從同取自P(T)未參與訓(xùn)練的測(cè)試數(shù)據(jù)集(測(cè)試集中的樣本和訓(xùn)練集中的樣本類別不同)中采樣測(cè)試元任務(wù),對(duì)訓(xùn)練好的模型進(jìn)行測(cè)試。
盡管現(xiàn)階段的神經(jīng)網(wǎng)絡(luò)研究缺少理論支撐,但是大量經(jīng)驗(yàn)證據(jù)表明,卷積權(quán)值共享和網(wǎng)絡(luò)深度對(duì)于神經(jīng)網(wǎng)絡(luò)的效果起到了重要作用[21-22]。卷積權(quán)值共享的有效性依賴于其在多數(shù)感知任務(wù)中都具有平移不變性,即預(yù)測(cè)標(biāo)簽的函數(shù)和數(shù)據(jù)分布對(duì)于平移變換都近似于不變。由于平移不變性,共享權(quán)重的卷積核可以從圖像的局部區(qū)域提取特征,并且參數(shù)量遠(yuǎn)少于全連接網(wǎng)絡(luò)[23],同時(shí)能夠?qū)W到更多有效的變換信息[24-25]。卷積層可以有效地應(yīng)用于深度網(wǎng)絡(luò)中,因?yàn)檫@種網(wǎng)絡(luò)中的所有層都具有平移不變性:將圖片平移后再送入若干卷積層得到的結(jié)果,與將原圖直接送入相同卷積層再對(duì)特征圖進(jìn)行平移所得到的結(jié)果相同[26]。因此,為提高特征提取能力,本文使用G-CNN 來構(gòu)建映射網(wǎng)絡(luò),使映射網(wǎng)絡(luò)對(duì)具有旋轉(zhuǎn)對(duì)稱的特征和鏡像對(duì)稱的特征也能保持不變性。映射網(wǎng)絡(luò)使用4 層G-CNN 構(gòu)建,每層由卷積核、batch-norm、relu 激活函數(shù)和最大池化層組成。
對(duì)于輸入的2 維圖片,卷積是不斷平移卷積核和特征圖做點(diǎn)積運(yùn)算的過程,以群G上的函數(shù)代替平移就得到了群卷積,如式(2)所示:
其中:Z2是2 維圖片上的整數(shù)平移群;群運(yùn)算是加法(n,m)+(p,q)=(n+p,m+q);f是輸入的特征圖;φ是卷積核。f和φ都是Z2上的函數(shù),只適用于群卷積的第1 層,但由于卷積輸出的結(jié)果是離散群G上的函數(shù),因此第1 層后的卷積如式(3)所示:
其中:輸入的特征圖f是群G上的函數(shù)。
令h=uh,等變性證明如式(4)所示:
映射網(wǎng)絡(luò)中的非線性單元包括激活函數(shù),可以將非線性單元看作一個(gè)映射:v:R →R,非線性單元作用于特征圖f可以視為一系列操作算子的組合,如式(5)所示:
因此,使用非線性單元處理特征圖后依然能保持等變性。
池化可以分解為不帶步長的池化和下采樣[27]兩部分。對(duì)于不帶步長的池化,定義池化操作為P,作用于特征圖f的最大池化如式(6)所示(平均池化同理):
其中:gU是G的子群U上的一個(gè)g變換。在G-CNN中,下采樣表示在G的子群H上下采樣。例如:對(duì)輸入2 維圖片做步長為2 的最大池化,等價(jià)于先進(jìn)行不帶步長的池化,再在Z2的子群H={(2i,2j)|(i,j)∈Z2}上進(jìn)行下采樣。
對(duì)于具有90°旋轉(zhuǎn)對(duì)稱特征的圖片,群G使用p4 群;對(duì)于具有90°旋轉(zhuǎn)對(duì)稱和鏡像對(duì)稱的特征,群G使用p4m 群[28]。p4 群的群元定義如式(7)所示:
其中:0≤r<4,r=0 表示無旋 轉(zhuǎn),r=1 表示旋轉(zhuǎn)90°;(u,v)∈Z2,表示在二維平面上的水平和垂直移動(dòng)。群運(yùn)算為矩陣乘法。對(duì)于輸入的特征圖上的某點(diǎn)(x,y),p4 群作用于點(diǎn)(x,y)的運(yùn)算如式(8)所示:
其中:m=0 或1,1 表示鏡像翻轉(zhuǎn),其余定義與p4 群相同,群運(yùn)算為矩陣乘法。作用于輸入特征圖上某點(diǎn)(x,y)的運(yùn)算如式(10)所示:
當(dāng)群G使用p4群時(shí),第1層的G-CNN 是Z2-p4 卷積層,操作如圖2 所示,依次將卷積核旋轉(zhuǎn)90°,得到4 組卷積核,分別與輸入圖片做卷積,得到4 組映射特征。第一層后面的G-CNN 是p4-p4 卷積,操作如圖3所示,對(duì)于前層輸入的4 組映射特征,卷積核依次旋轉(zhuǎn)90°得到4 組卷積核,然后每組卷積核依次和輸入的4 組特征做卷積,將得到的結(jié)果求和得到輸出特征。使用p4m 群構(gòu)建映射網(wǎng)絡(luò)時(shí),卷積核需要額外進(jìn)行鏡像翻轉(zhuǎn),因此,卷積核的數(shù)目是8 組,得到的輸出特征也是8 組,操作與使用p4 群類似。
圖2 Z2-p4 卷積層示意圖Fig.2 Schematic diagram of Z2-p4 convolution layer
圖3 p4-p4 卷積層示意圖Fig.3 Schematic diagram of p4-p4 convolution layer
本文算法基于以下基本假設(shè):存在一個(gè)空間,在這個(gè)空間中,屬于相同類別的樣本距離近,不同類別的樣本距離遠(yuǎn),這樣就可以通過簡單度量函數(shù)進(jìn)行分類。本文算法是通過學(xué)習(xí)一個(gè)映射網(wǎng)絡(luò)將樣本映射到合適的度量空間,然后通過簡單度量方法完成分類。在n-wayk-shot 元學(xué)習(xí)問題中,對(duì)于每個(gè)元任務(wù),支持集中每類有k個(gè)樣本,支持集經(jīng)過映射網(wǎng)絡(luò)映射到度量空間后,每一類就有k個(gè)表示,取每類k個(gè)表示的均值作為該類在度量空間中的代表。每個(gè)類在度量空間的代表稱為該類的原型點(diǎn)cj,計(jì)算公式如式(11)所示:
其中:k表示支持集中每類樣本的個(gè)數(shù);fθ表示映射網(wǎng)絡(luò);(xi,yj)表示輸入的樣本和對(duì)應(yīng)的標(biāo)簽。查詢集經(jīng)過同樣的映射網(wǎng)絡(luò)映射到度量空間中,利用距離計(jì)算函數(shù)d來計(jì)算查詢集中待分類樣本到每類原型點(diǎn)的距離,再利用softmax 函數(shù)計(jì)算屬于每個(gè)類的概率,如式(12)所示:
最后,使用交叉熵作為損失函數(shù),如式(13)所示:
通過Adam 優(yōu)化器來最小化損失函數(shù),從而優(yōu)化映射網(wǎng)絡(luò)的參數(shù),不斷從訓(xùn)練集中抽取樣本組成元任務(wù)來訓(xùn)練模型,直到得到一個(gè)能很好地將訓(xùn)練樣本映射到合適度量空間的模型。
本文提出的基于群等變卷積的度量元學(xué)習(xí)算法(Metric Meta-learning algorithm Based on Group Equivariant Convolution,MMBOGEC)如算法1 所示。
算法1MMBOGEC
輸入訓(xùn)練集D={(x1,y1),(x2,y2),…,(xN,yN)}
輸出模型在測(cè)試集上的分類準(zhǔn)確率
1)在訓(xùn)練集中隨機(jī)選取n個(gè)類,對(duì)于選取的每個(gè)類,取k個(gè)樣本組成支持集,取Nq個(gè)樣本組成查詢集。
2)通過映射網(wǎng)絡(luò)得到支持集樣本在度量空間中的表示,取每個(gè)類所有樣本在度量空間中特征表示的均值作為該類的原型點(diǎn)。
3)利用同樣的映射網(wǎng)絡(luò)得到查詢集樣本在度量空間中的表示,利用距離計(jì)算公式計(jì)算查詢集樣本在度量空間中的表示到每個(gè)類原型點(diǎn)的距離,利用softmax 函數(shù)計(jì)算屬于每個(gè)類的概率,將概率最大的類別作為預(yù)測(cè)類別。
4)使用交叉熵作為損失函數(shù)更新?lián)p失J。
5)使用Adam 優(yōu)化器最小化損失J來更新網(wǎng)絡(luò)參數(shù)。
6)重復(fù)步驟1~步驟5,直到損失J不再下降。
7)在測(cè)試集中生成若干個(gè)元任務(wù),每個(gè)元任務(wù)隨機(jī)選取n個(gè)類,對(duì)于選取的每個(gè)類,取k個(gè)樣本組成支持集,取Nq個(gè)樣本組成查詢集,將這些元任務(wù)輸入訓(xùn)練好的模型,得到分類準(zhǔn)確率,最后將分類準(zhǔn)確率的平均值作為輸出結(jié)果。
本文在常用的小樣本數(shù)據(jù)集miniImageNet 和Omniglot 上進(jìn)行實(shí)驗(yàn)。
miniImageNet 數(shù)據(jù)集包含60 000 張彩色圖片,分為100 個(gè)類,每個(gè)類600 張。首先將所有圖片處理成84 像素×84 像素大小,將其中的64 類作為訓(xùn)練集,16 類作為驗(yàn)證集,剩下的20 類作為測(cè)試集。本文使用64 類來訓(xùn)練模型,驗(yàn)證集僅僅用來判斷模型泛化性的好壞,不參與模型的參數(shù)優(yōu)化。
輸入的樣本圖片經(jīng)過映射網(wǎng)絡(luò)得到其在度量空間中的特征表示,映射網(wǎng)絡(luò)包含4 層由G-CNN 構(gòu)成的卷積,每一層使用64個(gè)3×3卷積核,包含batch-norm、relu 激活函數(shù)以及3×3 的最大池化層。最后將得到的特征表示展開成一維向量,利用距離計(jì)算函數(shù)計(jì)算其到各個(gè)原型點(diǎn)的距離,將距離最近的類別作為預(yù)測(cè)標(biāo)簽。以交叉熵作為損失函數(shù),不添加正則項(xiàng)損失,學(xué)習(xí)率設(shè)置為10-3,使用Adam 優(yōu)化器對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行優(yōu)化。
針對(duì)miniImageNet 數(shù)據(jù)集常用的有兩種訓(xùn)練方法,分別是5-way 1-shot 和5-way 5-shot。5-way 1-shot訓(xùn)練方法先任意地從訓(xùn)練集中選5 個(gè)類別,每個(gè)類別包含1 個(gè)樣本,總計(jì)5 個(gè)樣本作為支持集,再從前面5 類中每類選取若干個(gè)不同的樣本(本文實(shí)驗(yàn)中設(shè)置為15 個(gè))作為查詢集,使模型根據(jù)支持集來分類查詢集。5-way 5-shot 訓(xùn)練方法將支持集每類選取樣本數(shù)改為5,其余和前面一致。當(dāng)驗(yàn)證集上的驗(yàn)證損失不再下降時(shí),停止訓(xùn)練模型,在測(cè)試集上驗(yàn)證模型的效果,測(cè)試方法和訓(xùn)練方法保持一致,測(cè)試使用隨機(jī)產(chǎn)生的600 個(gè)元任務(wù),以平均準(zhǔn)確率作為評(píng)估指標(biāo)。
4.1.1 不同距離計(jì)算公式對(duì)實(shí)驗(yàn)結(jié)果的影響
不同距離的度量公式會(huì)對(duì)算法的實(shí)驗(yàn)結(jié)果產(chǎn)生影響,本文使用常用的4 種距離計(jì)算公式進(jìn)行測(cè)試,分別是歐式距離、余弦距離、切比雪夫距離和曼哈頓距離,測(cè)試結(jié)果對(duì)比如表1 所示??梢钥闯?,在miniImageNet 數(shù)據(jù)集5-way 1-shot 和5-way 5-shot 方法中,歐氏距離作為距離計(jì)算公式最有效,其次是曼哈頓距離,切比雪夫距離最差。
表1 使用不同距離計(jì)算公式的實(shí)驗(yàn)結(jié)果對(duì)比Table 1 Comparison of experimental results using different distance calculation formulas %
4.1.2 消融實(shí)驗(yàn)
為驗(yàn)證本文算法的有效性,分別使用p4 群、p4m群和普通CNN 構(gòu)建映射網(wǎng)絡(luò)行實(shí)驗(yàn),對(duì)比實(shí)驗(yàn)結(jié)果如表2 所示??梢钥闯觯翰皇褂萌旱茸兙矸e的方法,實(shí)驗(yàn)結(jié)果最差;使用p4 群的方法,實(shí)驗(yàn)結(jié)果優(yōu)于使用普通CNN 的方法,表明在本實(shí)驗(yàn)中,具有旋轉(zhuǎn)不變性的方法比不具有旋轉(zhuǎn)不變性的方法更有效;使用p4m 群的方法,實(shí)驗(yàn)效果最好,表明利用旋轉(zhuǎn)不變性和鏡像對(duì)稱不變性能有效提高元學(xué)習(xí)的自適應(yīng)性。
表2 消融實(shí)驗(yàn)結(jié)果對(duì)比Table 2 Comparison of ablation experimental results %
4.1.3 G-CNN 層數(shù)對(duì)實(shí)驗(yàn)結(jié)果的影響
為進(jìn)一步驗(yàn)證群等變卷積的有效性,在部分卷積層上使用群等變卷積進(jìn)行實(shí)驗(yàn),實(shí)驗(yàn)結(jié)果如表3所示,其中第1 列表示使用群等變卷積的卷積層,如1 表示僅在第1 層使用,其余層使用普通CNN??梢钥闯?,在5-way 1-shot 和5-way 5-shot 的實(shí)驗(yàn)中,僅僅在單層中使用群等變卷積,不論是在哪一層使用,實(shí)驗(yàn)結(jié)果都相差不大,表明僅在某一層具有等變性不能很好地提升元學(xué)習(xí)的自適應(yīng)性。隨著使用群等變卷積層數(shù)的增加,實(shí)驗(yàn)效果隨之提升,完整的4 層群等變卷積網(wǎng)絡(luò)效果最好,表明整個(gè)網(wǎng)絡(luò)都具有等變性才能更好地適用于元學(xué)習(xí)問題。
表3 在不同卷積層使用G-CNN 的實(shí)驗(yàn)結(jié)果對(duì)比Table 3 Comparison of experimental results using G-CNN in different convolutional layers %
4.1.4 與4 層元學(xué)習(xí)算法的實(shí)驗(yàn)結(jié)果對(duì)比
將本文算法與傳統(tǒng)4 層元學(xué)習(xí)算法進(jìn)行對(duì)比,實(shí)驗(yàn)結(jié)果如表4 所示(加粗?jǐn)?shù)據(jù)表示最優(yōu)數(shù)據(jù))??梢钥闯?,無論是5-way 1-shot 還是5-way 5-shot,本文算法性能都優(yōu)于傳統(tǒng)4 層元學(xué)習(xí)算法。
表4 不同算法在miniImageNet數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果對(duì)比Table 4 Comparison of experimental results of different algorithms on miniImageNet dataset %
Omniglot數(shù)據(jù)集包含50種不同語言,共計(jì)1 623種手寫字符,每種字符包含20個(gè)樣本,每個(gè)樣本由不同人書寫。本文將樣本圖片大小統(tǒng)一為28 像素×28 像素,使用其中的1 028 類作為訓(xùn)練集,423 類作為測(cè)試集,剩下的作為驗(yàn)證集。
輸入的樣本圖片經(jīng)過映射網(wǎng)絡(luò)得到其在度量空間中的特征表示,映射網(wǎng)絡(luò)包含4 層由G-CNN 構(gòu)成的卷積層,每層使用64 個(gè)3×3 卷積核、batch-norm、relu 激活函數(shù)以及3×3 的最大池化層。在度量空間中使用歐氏距離計(jì)算查詢集到原型點(diǎn)的距離,將距離最短的原型點(diǎn)對(duì)應(yīng)的標(biāo)簽作為預(yù)測(cè)標(biāo)簽,以交叉熵作為損失函數(shù),不添加正則項(xiàng)損失,學(xué)習(xí)率設(shè)置為10-3,使用Adam 優(yōu)化器對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行優(yōu)化。
Omniglot 數(shù)據(jù)集常用的有4 種訓(xùn)練方法,分別是5-way 1-shot、5-way 5-shot、20-way 1-shot 和20-way 5-shot,測(cè)試時(shí)同樣使用對(duì)應(yīng)的方法。測(cè)試使用隨機(jī)產(chǎn)生的1 000 個(gè)元任務(wù),以平均準(zhǔn)確率作為最后的結(jié)果。
本文算法與傳統(tǒng)4 層元學(xué)習(xí)算法在Omniglot 數(shù)據(jù)集上實(shí)驗(yàn)結(jié)果對(duì)比如表5 所示(加粗?jǐn)?shù)據(jù)表示最優(yōu)數(shù)據(jù)),可以看出,在5-way 1-shot、5-way 5-shot 實(shí)驗(yàn)中,本文算法性能都優(yōu)于其他算法。
表5 不同算法在Omniglot 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果對(duì)比Table 5 Comparison of experimental results of different algorithms on Omniglot dataset %
本文算法針對(duì)n-wayk-shot 元學(xué)習(xí)問題,對(duì)于每個(gè)元任務(wù),需要n類支持集樣本,每類樣本包含k個(gè)實(shí)例,對(duì)q個(gè)支持集樣本進(jìn)行分類,因此每個(gè)元任務(wù)的平均復(fù)雜度為O(n×k×q)。
MMBOGEC 算法與傳統(tǒng)4 層元學(xué)習(xí)算法的參數(shù)量對(duì)比如表6 所示(加粗?jǐn)?shù)據(jù)表示最優(yōu)數(shù)據(jù))??梢钥闯?,MMBOGEC 算法參數(shù)量只比原型網(wǎng)絡(luò)算法多,而少于其他4 種算法。
表6 不同算法的參數(shù)量對(duì)比Table 6 Comparison of the number of parameters of different algorithms
針對(duì)傳統(tǒng)機(jī)器學(xué)習(xí)的自適應(yīng)性問題,本文提出一種基于群等變卷積的度量元學(xué)習(xí)算法,使用群等變卷積神經(jīng)網(wǎng)絡(luò)構(gòu)建映射網(wǎng)絡(luò),充分利用樣本圖片的局部旋轉(zhuǎn)對(duì)稱性和鏡像對(duì)稱性,將樣本圖片映射到合適的度量空間,根據(jù)所提取特征到每個(gè)類原型點(diǎn)的距離遠(yuǎn)近來實(shí)現(xiàn)分類。在Omniglot 數(shù)據(jù)集和miniImageNet數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果表明,該算法對(duì)于元學(xué)習(xí)問題的學(xué)習(xí)性能優(yōu)于傳統(tǒng)4 層元學(xué)習(xí)算法。下一步將對(duì)本文算法進(jìn)行改進(jìn),探索更有效的特征映射網(wǎng)絡(luò)和特征距離比較方法。