張曉峰,吳 剛
(中國科學技術(shù)大學 信息科學技術(shù)學院,合肥 230031)
隨著近些年來深度學習的發(fā)展,深度神經(jīng)網(wǎng)絡(luò)[1]在分類任務上取得了革命性的突破.基于深度神經(jīng)網(wǎng)絡(luò)的分類器在有充足標簽樣本為訓練數(shù)據(jù)的前提下可以達到很高的準確度.但是往往在一些場景下,有標簽的數(shù)據(jù)難以收集或者獲取這些數(shù)據(jù)成本高昂,費時費力.當數(shù)據(jù)不足時,神經(jīng)網(wǎng)絡(luò)很難穩(wěn)定訓練并且泛化能力較弱.如何在小規(guī)模數(shù)據(jù)集上有效的訓練神經(jīng)網(wǎng)絡(luò)成為當下的一個研究熱點.常見的應對小規(guī)模數(shù)據(jù)集訓練問題的措施主要有以下3種:
(1)無監(jiān)督預訓練和有監(jiān)督微調(diào)相結(jié)合的方法.通過引入和訓練數(shù)據(jù)具有相同分布的大量無標簽數(shù)據(jù)的方式,神經(jīng)網(wǎng)絡(luò)可以先收斂到一個較優(yōu)的初始點,然后再在小數(shù)據(jù)上微調(diào).但是這種方式存在一個潛在的假設(shè):無標簽的數(shù)據(jù)容易獲得而且收集成本不高,但是在一些數(shù)據(jù)難以獲取的場景中,例如醫(yī)療圖像,這種方法將無法應用.
(2)遷移學習[2]的方法.相比于第一種方法,遷移學習的要求更加寬泛,額外的無標簽數(shù)據(jù)不需要和訓練數(shù)據(jù)具有相同的分布,只要相似或者分布有重疊即可.在視覺識別當中,一些視覺的基本模式像邊緣、紋理等在自然圖像中都是共通的,這一點構(gòu)成了遷移學習的理論保證.大量的實踐表明,在源領(lǐng)域(source domain)上學習大量數(shù)據(jù)后的網(wǎng)絡(luò)再遷移到目標領(lǐng)域(target domain)上,網(wǎng)絡(luò)的性能會得到極大的提升.但是當源領(lǐng)域和目標領(lǐng)域之間差距甚大時,遷移學習是否有所幫助,目前還未有研究.
(3)數(shù)據(jù)增強[3]的方法.通過合成或者轉(zhuǎn)換的方式,從有限的數(shù)據(jù)中生成新的數(shù)據(jù),數(shù)據(jù)增強技術(shù)一直以來都是一種重要的克服數(shù)據(jù)不足的手段.傳統(tǒng)的圖像領(lǐng)域的數(shù)據(jù)增強技術(shù)是建立在一系列已知的仿射變換——例如旋轉(zhuǎn)、縮放、位移等,以及一些簡單的圖像處理手段——例如光照色彩變換、對比度變換、添加噪聲等基礎(chǔ)上的.這些變化的前提是不改變圖像的標簽,并且只能局限在圖像領(lǐng)域.這種基于幾何變換和圖像操作的數(shù)據(jù)增強方法可以在一定程度上緩解神經(jīng)網(wǎng)絡(luò)過擬合的問題,提高泛化能力.但是相比與原始數(shù)據(jù)而言,增加的數(shù)據(jù)點并沒有從根本上解決數(shù)據(jù)不足的難題;同時,這種數(shù)據(jù)增強方式需要人為設(shè)定轉(zhuǎn)換函數(shù)和對應的參數(shù),一般都是憑借經(jīng)驗知識,最優(yōu)數(shù)據(jù)增強通常難以實現(xiàn),所以模型的泛化性能只能得到有限的提升.
最近興起的一些生成模型,由于其出色的性能引起了人們的廣泛關(guān)注.例如變分自編碼網(wǎng)絡(luò)(Variational Auto-Encoding network,VAE)[4]和生成對抗網(wǎng)絡(luò)(Generative Adversarial Network,GAN)[5],其生成樣本的方法也可以用于數(shù)據(jù)增強.這種基于網(wǎng)絡(luò)合成的方法相比于傳統(tǒng)的數(shù)據(jù)增強技術(shù)雖然過程更加復雜,但是生成的樣本更加多樣,同時還可以應用于圖像編輯,圖像去噪等各種場景.本文主要介紹的是基于生成對抗網(wǎng)絡(luò)的數(shù)據(jù)增強技術(shù),并將這種方法應用于小規(guī)模數(shù)據(jù)集的分類任務.
生成模型可以分成顯式密度模型和隱式密度模型兩種.生成對抗網(wǎng)絡(luò)是一種隱式密度模型,即網(wǎng)絡(luò)沒有顯式的給出數(shù)據(jù)分布的密度函數(shù),GAN的網(wǎng)絡(luò)結(jié)構(gòu)如圖1所示,是由生成網(wǎng)絡(luò)(Generator,G)和判別網(wǎng)絡(luò)(Discriminator,D)兩部分組成.假設(shè)在低維空間Z存在一個簡單容易采樣的分布p(z),例如標準正態(tài)分布N(0,I),生成網(wǎng)絡(luò)構(gòu)成一個映射函數(shù)G:Z→X,判別網(wǎng)絡(luò)需要判別輸入是來自真實數(shù)據(jù)還是生成網(wǎng)絡(luò)生成的數(shù)據(jù).生成網(wǎng)絡(luò)輸入噪聲z,輸出生成的圖像數(shù)據(jù);判別網(wǎng)絡(luò)輸入的數(shù)據(jù)或者來自真實數(shù)據(jù)集,或者來自生成網(wǎng)絡(luò)合成的數(shù)據(jù),輸出數(shù)據(jù)為真的概率.
圖1 GAN結(jié)構(gòu)示意圖
G和D相互競爭:G試圖欺騙D從而以假亂真,而D則不斷提高甄別能力防止G合成的數(shù)據(jù)魚目混珠,理論上最終生成的數(shù)據(jù)分布Pg和真實的數(shù)據(jù)分布Pdata可以相等.可以用式(1)概括整個GAN網(wǎng)絡(luò)的優(yōu)化函數(shù):
GAN本質(zhì)上屬于無監(jiān)督學習的范疇,其判別網(wǎng)絡(luò)僅僅輸出數(shù)據(jù)真假的概率.條件生成對抗網(wǎng)絡(luò)(Conditional-GAN)[6]在GAN的基礎(chǔ)上,加入類別的信息Y,從而可以生成指定類別的數(shù)據(jù).Conditional-GAN的優(yōu)化函數(shù)可以寫成式(2):
Conditional-GAN的判別器D仍然只有一個輸出來判斷真假,而半監(jiān)督學習生成對抗網(wǎng)絡(luò)(Semi-GAN)[7]在Conditional-GAN 的基礎(chǔ)上,判別器輸出增加到K+1個(K代表數(shù)據(jù)的類別個數(shù)),K個輸出表示真實數(shù)據(jù)的分類概率,第K+1個表示數(shù)據(jù)為假的概.Conditional-GAN和Semi-GAN的結(jié)構(gòu)如圖2所示.
圖2 Conditional-GAN與Semi-GAN結(jié)構(gòu)對比
本文從數(shù)據(jù)增強的目的出發(fā),通過改進生成對抗網(wǎng)絡(luò)的結(jié)構(gòu)和訓練算法,設(shè)計了一種基于生成對抗網(wǎng)絡(luò)的數(shù)據(jù)增強技術(shù),并提出了一種新的網(wǎng)絡(luò)結(jié)構(gòu),即數(shù)據(jù)增強生成對抗網(wǎng)絡(luò)(Data Augmentation GAN,DAGAN).與其他的GAN結(jié)構(gòu)相比,我們提出的網(wǎng)絡(luò)結(jié)構(gòu)更加適用于數(shù)據(jù)增強任務,即生成的樣本和原始數(shù)據(jù)真假難分的同時,還可以做到類間可分,從而有利于分類器在在合成的數(shù)據(jù)點上學習到分類界限.在訓練算法上,本文將DAGAN的訓練過程和分類器的訓練過程相結(jié)合,并提出一種新的損失函數(shù),稱之為“2K”損失函數(shù),從而可以做到在線數(shù)據(jù)增強,即數(shù)據(jù)處理和分類器訓練可以在內(nèi)存中同步處理,不需要另外的數(shù)據(jù)存儲空間.
一般的GAN網(wǎng)絡(luò)其判別器僅僅只有一個輸出——判斷輸入的真假,如果直接用來生成數(shù)據(jù)用來做數(shù)據(jù)增強是不可行的,因為不能做到按類別生成樣本.Conditional-GAN和Semi-Supervised GAN雖然可以利用數(shù)據(jù)的標簽信息,并且按照給定的類別生成相應的數(shù)據(jù),但是相關(guān)的研究工作表明這樣的GAN結(jié)構(gòu)其生成的樣本多樣性不足,對數(shù)據(jù)增強的貢獻十分有限.因此,需要針對我們數(shù)據(jù)增強的這一特定需求,即生成的數(shù)據(jù)有利于分類器學習更加緊湊的分類界限,提升分類性能來設(shè)計網(wǎng)絡(luò)結(jié)構(gòu).基于以上考慮,從生成網(wǎng)絡(luò)的角度來看,最優(yōu)的判別網(wǎng)絡(luò)需要:
1)能夠正確地將真實數(shù)據(jù)和生成數(shù)據(jù)分類;
2)不能分辨數(shù)據(jù)是真實的還是合成的.
據(jù)此,在GAN的基礎(chǔ)上設(shè)計出適合于小規(guī)模數(shù)據(jù)增強任務的GAN網(wǎng)絡(luò)結(jié)構(gòu),即DAGAN.結(jié)構(gòu)如圖3所示.
圖3 DAGAN網(wǎng)絡(luò)結(jié)構(gòu)
這里,生成網(wǎng)絡(luò)采用Conditional-GAN的結(jié)構(gòu),隱向量z和類別信息y作為輸入,輸出對應類別的數(shù)據(jù);判別網(wǎng)絡(luò)的輸入有兩個來源——真實數(shù)據(jù)或者生成的數(shù)據(jù),輸出則變?yōu)?K個,前K個表示輸入為真實數(shù)據(jù)K類的概率,后K個表示輸入為生成數(shù)據(jù)K類的概率.
可以看出,就判別網(wǎng)絡(luò)而言,從GAN到Semi-Supervised GAN,再到本文提出的DAGAN,輸出的維度不斷增加,同時應用的領(lǐng)域也更加廣泛.就生成網(wǎng)絡(luò)而言,Conditional-GAN,Semi-Supervised GAN以及本文的DAGAN都利用了數(shù)據(jù)的標簽信息,可以根據(jù)指定的類別生成相應的數(shù)據(jù).DAGAN在利用Conditional-GAN生成器結(jié)構(gòu)的同時,又增強了判別網(wǎng)絡(luò)的判別能力,使之適用于小規(guī)模數(shù)據(jù)集的增強.表1總結(jié)了以上幾種GAN網(wǎng)絡(luò)的特點對比.
表1 幾種GAN網(wǎng)絡(luò)的對比
DAGAN的訓練分成兩個階段,第一階段為數(shù)據(jù)生成階段.生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)優(yōu)化相反的目標函數(shù),在不斷的對抗中達到平衡.與GAN不同的是,由于判別網(wǎng)絡(luò)有2K個輸出,因此相應的損失函數(shù)也將發(fā)生改變,稱之為“2K”損失函數(shù).對于判別網(wǎng)絡(luò),其損失函數(shù)如下:
對于生成網(wǎng)絡(luò),除了對應的判別真假的損失函數(shù)之外,還包括正則化項,用來保證生成的數(shù)據(jù)和真實的數(shù)據(jù)在特征層面盡可能保持相近,損失函數(shù)如(4)式所示:
其中,Lfm為正則化項,具體形式如下:
這里f(x)函數(shù)判別網(wǎng)絡(luò)中間某一層的輸出,即要求在相同類別的前提下,生成數(shù)據(jù)和真實數(shù)據(jù)特征應當相近,這進一步保證了生成數(shù)據(jù)和真實數(shù)據(jù)在同一類別下具有相同的語義.
第二階段為分類訓練階段,假設(shè)第一階段訓練完成之后,生成網(wǎng)絡(luò)已經(jīng)學習到真實數(shù)據(jù)的分布.因此在這一階段,生成網(wǎng)絡(luò)將不再進行訓練,僅僅作為一個數(shù)據(jù)的提供者,生成的數(shù)據(jù)和真實數(shù)據(jù)一起訓練分類網(wǎng)絡(luò).值得注意的是,這里不需要單獨搭建新的分類網(wǎng)絡(luò),判別網(wǎng)絡(luò)直接作為分類器進行訓練.由于判別網(wǎng)絡(luò)有2K個輸出,這里規(guī)定第i個與第k+i個輸出的概率之和表示輸入為第i(i=1,2,…,k)類數(shù)據(jù)的概率.第二階段的判別網(wǎng)絡(luò)的損失函數(shù)由兩部分構(gòu)成,分別是真實數(shù)據(jù)和生成數(shù)據(jù):
其中,
兩個階段均采用批量隨機梯度下降的算法進行參數(shù)更新,具體流程見算法1.
算法1.DAGAN批量隨機梯度下降訓練算法輸入:第一階段的迭代次數(shù)KG,第二階段的迭代次數(shù)KC,訓練集D,測試集T,批次數(shù)量B 1)數(shù)據(jù)生成階段訓練:分別采樣真實數(shù)據(jù)(x,y)~Pdata(x,y),以及隱向量數(shù)據(jù)z~P(z),隨機類別數(shù)據(jù)y~Pg.在KG次迭代中,采用隨機梯度下降的方法,交替更新生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò),損失函數(shù)分別為LG和LC.2)數(shù)據(jù)分類階段訓練:分別采樣真實數(shù)據(jù)(x,y)~Pdata(x,y),以及隱向量數(shù)據(jù)z~P(z),隨機類別數(shù)據(jù)y~Pg,在KC次迭代中,采用隨機梯度下降的算法,只更新判別網(wǎng)絡(luò),損失函數(shù)為L’C.3)在測試集上測試判別網(wǎng)絡(luò)的準確率.
為了驗證DAGAN的生成能力以及生成樣本能否提升分類器的準確率,我們分別在3個數(shù)據(jù)集上做了驗證實驗,分別為CIFAR-10、SVHN以及KDEF數(shù)據(jù)集.實驗中的網(wǎng)絡(luò)結(jié)構(gòu)都是基于DCGAN[8]這個網(wǎng)絡(luò)搭建,詳細的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)如表2所示.
表2 CIFAR10實驗網(wǎng)絡(luò)參數(shù)與網(wǎng)絡(luò)結(jié)構(gòu)(SVHN與KDEF數(shù)據(jù)集實驗與之類似)
這里需要說明:G,D,T-Conv,Conv,NIN,NL分別表示生成網(wǎng)絡(luò),判別網(wǎng)絡(luò),反卷積,卷積,Network in Network,非線性激活函數(shù).
CIFAR-10數(shù)據(jù)集總共包含60 000張RGB圖片,其中50 000張為訓練圖片,10 000張為測試集圖片.圖片為32×32的分辨率,總共可以分成10類.為了探究各種數(shù)據(jù)增強方式對于不同程度的小規(guī)模數(shù)據(jù)集的影響,我們?nèi)藶榈貜脑摂?shù)據(jù)集中抽取不同數(shù)量的子數(shù)據(jù)集,每類從50到1000不等.實驗主要對比以下幾種不同的數(shù)據(jù)增強方式:(1)不采用任何的數(shù)據(jù)增強方式(C);(2)傳統(tǒng)的基于仿射變換和圖像操作的數(shù)據(jù)增強方式(C_aug);(3)GAN在每一類上分別訓練,然后每一類單獨生成數(shù)據(jù)(Vanilla GAN);(4)Semi-Supervised GAN 生成數(shù)據(jù)(Semi GAN);(5)本文所提出的方法(DAGAN);(6)本文所提出的方法加上傳統(tǒng)的數(shù)據(jù)增強技術(shù)(DAGAN_aug).實驗對比了不同方法下訓練出來的分類器在測試集上的分類準確率(Acc),結(jié)果見表3.
表3 不同數(shù)據(jù)增強方式在CIFAR10數(shù)據(jù)集上測試集的準確率(%)
從實驗結(jié)果可以看出,DAGAN_aug是所有方法中對分類器提升最顯著的,表明DAGAN可以在傳統(tǒng)數(shù)據(jù)增強的基礎(chǔ)上進一步提升模型的性能,突破傳統(tǒng)數(shù)據(jù)增強的瓶頸.另外可以看出DAGAN在數(shù)據(jù)量較少的時候(每類圖片數(shù)量小于500張)要優(yōu)于Vanilla GAN和Semi GAN,說明本文針對數(shù)據(jù)增強目的設(shè)計的DAGAN網(wǎng)絡(luò)結(jié)構(gòu)和訓練算法更加有利于分類器的性能提升.
SVHN[9]是真實世界的街道門牌號碼識別數(shù)據(jù)集,每張圖片代表0-9中的一個數(shù)字,分辨率為32×32.由于每種圖片中可能包含不止一種數(shù)字,而標簽為中心的數(shù)字.傳統(tǒng)的數(shù)據(jù)增強方式例如翻轉(zhuǎn)、移位等在這樣的數(shù)據(jù)中將不能應用,因為這些轉(zhuǎn)換方式可能會改變圖像的標簽.同樣地,表4給出了不同種數(shù)據(jù)增強方式在SVHN數(shù)據(jù)集上的性能對比.實驗僅僅考慮了3種數(shù)據(jù)增強方式的對比,即(1)不采用任何的數(shù)據(jù)增強方式(C);(2)Semi-Supervised GAN生成數(shù)據(jù)(Semi GAN);(3)本文提出的方法(DAGAN).
表4 不同數(shù)據(jù)增強方式在SVHN數(shù)據(jù)集上測試集的準確率(%)
實驗結(jié)果和CIFAR10數(shù)據(jù)集是一致的,在數(shù)據(jù)量較少的情況下,DAGAN能夠最大程度的提升分類器的分類性能,且優(yōu)于Semi GAN的方法.有一點需要注意,當數(shù)據(jù)量較多時(每類圖片數(shù)為500張),Semi GAN和DAGAN兩種方法幾乎都不起作用,這主要是因為對于相對比較簡單的SVHN數(shù)據(jù)集,當訓練數(shù)據(jù)達到一定規(guī)模后,限制網(wǎng)絡(luò)性能的因素不再是數(shù)據(jù),而是分類網(wǎng)絡(luò)的結(jié)構(gòu)還有分類算法.
KDEF[10]數(shù)據(jù)集是一種人臉表情數(shù)據(jù)集,包含35個男性和35個女性,年齡在20至30歲之間.沒有胡須,耳環(huán)或眼鏡,且沒有明顯的化妝.7種不同的表情,每個表情有5個角度.總共4900張彩色圖,尺寸為562×762像素.實驗中我們僅采用正面角度,因此只有490張圖片,根據(jù)表情進行分類.
本次實驗生成網(wǎng)絡(luò)的結(jié)構(gòu)沒有變化,與表2類似,判別網(wǎng)絡(luò)采用VGG-16,由于數(shù)據(jù)量過少,因此我們采用的VGG-16是在ImageNet數(shù)據(jù)集上預訓練過的.實驗對比了以下幾種數(shù)據(jù)增強方式的性能:(1)不采用任何數(shù)據(jù)增強方式,僅僅是預訓練的分類器(C);(2)GAN在每一類上分別訓練,然后每一類單獨生成數(shù)據(jù)(Vanilla GAN);(3)Semi-Supervised GAN生成數(shù)據(jù)(Semi GAN);(4)本文所提的方法(DAGAN).實驗結(jié)果如表5所示,從結(jié)果來看,DAGAN依然是性能最好的結(jié)構(gòu),同時說明DAGAN可以和預訓練的策略相結(jié)合,進一步提升分類器的性能,突破數(shù)據(jù)增強技術(shù)的瓶頸.
表5 不同數(shù)據(jù)增強方式在KDEF數(shù)據(jù)集上測試集的準確率(%)
以上3個數(shù)據(jù)集的實驗說明了DAGAN結(jié)構(gòu)的可行性和有效性,為了進一步表明DAGAN生成的圖片和原始圖片具有相同的語義,而且呈現(xiàn)出內(nèi)容上的多樣性,這一部分將展示3個數(shù)據(jù)集上DAGAN生成的數(shù)據(jù)樣本,并和原始數(shù)據(jù)相比較,如圖4所示.
從生成圖片來看,CIFAR-10數(shù)據(jù)集每一行都是有著漸變的效果,這是通過對隱變量z差值實現(xiàn)的;而每一列都是一個不同的類別,這是通過控制類別信息y實現(xiàn)的.SVHN數(shù)據(jù)集每一行都是屬于相同的類別,而每一列圖片的z保持相同,所以每一列的圖片具有相同的風格.以上都說明DAGAN生成的圖片是可編輯的,同時也可以看出生成的圖像呈現(xiàn)比較豐富的多樣性,從而印證了DAGAN可以用于數(shù)據(jù)增強任務.
圖4 CIFAR-10數(shù)據(jù)集、SVHN數(shù)據(jù)集和KDEF數(shù)據(jù)集原始圖片和生成圖片對比
由于深度神經(jīng)網(wǎng)絡(luò)在小規(guī)模數(shù)據(jù)集上難以訓練,容易出現(xiàn)過擬合的問題,本文提出一種基于生成對抗網(wǎng)絡(luò)的數(shù)據(jù)增強技術(shù),通過在大量的實驗,以及和其他模型的對比,驗證了所提方法的可行性和有效性.DAGAN既可以有效提升分類器的分類性能,同時生成的圖像數(shù)據(jù)和真實數(shù)據(jù)相比具有語義的相似性和內(nèi)容的多樣性.