曾 琦, 向德華2,李 寧2,肖紅光
(1.長沙理工大學 計算機與通信工程學院,湖南 長沙 410114;2.湖南省計量檢測研究院,湖南 長沙 410014)
Ian J.Goodfellow在2014年首次提出生成對抗網(wǎng)絡(luò)(Generative Adversarial Networks,GAN)[1]的概念。GAN包括生成器(generator)與判別器(discriminator)兩個組成部分,生成器與判別器之間是一種博弈對抗的關(guān)系,判別器的目標是正確區(qū)分真實樣本與生成的偽樣本,而生成器的目標則是輸出的偽樣本能夠使判別器做出誤判。兩者經(jīng)過博弈對抗之后達到納什均衡[2],生成擬合真實數(shù)據(jù)的樣本。GAN目前應(yīng)用十分廣泛,如圖像識別、圖像超分辨率[3]、灰度圖像上色[4]、通信加密[5]、圖像合成[6]、根據(jù)文字生成圖片[7]等領(lǐng)域。在圖像識別領(lǐng)域,基于GAN的方法[8]雖然擁有很高的識別率,但與傳統(tǒng)有監(jiān)督學習方法如卷積神經(jīng)網(wǎng)絡(luò)一樣需要使用大量有標簽樣本進行訓練。在面對某些情況時這個問題影響了圖像識別效果,例如在進行醫(yī)學圖像分析時,雖然能夠采集大量樣本,但是需要富有經(jīng)驗的醫(yī)生進行樣本標注,不僅費時費力而且浪費了無標簽樣本中的有效信息。針對這個問題,本文結(jié)合半監(jiān)督生成對抗網(wǎng)絡(luò)[9](Semi-Supervised Generative Adversarial Networks,SSGAN)與深度卷積生成對抗網(wǎng)絡(luò)[10](Deep Convolutional Generative Adversarial Networks,DCGAN)建立半監(jiān)督深度生成對抗網(wǎng)絡(luò)模型(Semi Supervised Deep Convolutional Generative Adversarial Networks,SS-DCGAN)。SS-DCGAN使用有標簽樣本和無標簽樣本進行訓練,并且使用深度卷積網(wǎng)絡(luò)作為生成器與分類器。訓練后抽取模型的分類器部分用于圖像識別,實驗結(jié)果表明,該方法僅使用少量有標簽樣本即可達到與其他圖像識別方法同水平的識別率,解決了有標簽樣本數(shù)量較少的情況時識別效果不佳的問題。
生成對抗網(wǎng)絡(luò)模型由生成器和判別器兩個部分組成,生成器根據(jù)隨機噪聲生成偽樣本,判別器判斷輸入數(shù)據(jù)真?zhèn)巍I善骱团袆e器是能夠?qū)崿F(xiàn)生成樣本和判別真?zhèn)蔚挠成浜瘮?shù)模型,如多層感知器。生成對抗網(wǎng)絡(luò)的模型流程圖如圖1所示。
圖1 GAN流程圖
在生成器和判別器的博弈對抗中,生成器根據(jù)隨機噪聲z生成擬合真實數(shù)據(jù)Pdata的偽樣本G(z),G(z)和真實數(shù)據(jù)x一起輸入判別器,判別器輸出判別結(jié)果D(x)和D(G(z)),即輸入數(shù)據(jù)判別為“真”的概率。判別器得到判別結(jié)果后最小化實際輸出和期望輸出的交叉熵,而生成器通過判別器反饋最大化偽樣本判別為“真”的概率D(G(z)),此時生成器與判別器完成了一次優(yōu)化更新。但為了使生成器和判別器維持同水平博弈對抗,避免判別器過快達到最優(yōu)解使模型無法收斂,生成器與判別器的優(yōu)化更新頻率并非同步的,生成器更新幾次之后判別器才更新一次。經(jīng)過博弈對抗生成器與判別器達到納什均衡,此時D(G(z))=0.5,生成的樣本擬合真實數(shù)據(jù)分布。
Ez~Pz(z)[log(1-D(G(z)))]
(1)
半監(jiān)督生成對抗網(wǎng)絡(luò)(SSGAN)是2016年OpenAI提出的一種GAN改進模型。原始GAN訓練時使用無標簽樣本,生成沒有類別信息的樣本,與之相比SSGAN使用有標簽樣本和無標簽樣本共同訓練并生成帶有類別信息的樣本[9],也因此需要將GAN的判斷器替換為多分類器。
SSGAN模型的流程圖如圖2所示。隨機噪聲z通過生成器生成的偽樣本G(z)與k類有標簽樣本xl和無標簽樣本xu輸入分類器,輸出k+1維分類結(jié)果。前k維輸出代表對應(yīng)類置信度,第k+1維代表判定為“偽”的置信度。
圖2 SSGAN流程圖
深度卷積生成對抗網(wǎng)絡(luò)(DCGAN)由Alec Radford在2015年提出。DCGAN引入深度卷積網(wǎng)絡(luò)作為GAN的生成器和判別器,利用其強大的特征提取能力提升了GAN模型的表現(xiàn)[11]。
DCGAN與GAN比較做出了以下改變[10]:
① 分別用步幅卷積(Strided Convolutions)和微步幅卷積(Fractional-Strided Convolutions)替換了判別器和生成器中的池化層。
② 在生成器和判別器上使用了批量歸一化(Batchnorm)。批量歸一化可以解決初始化差的問題并在進行最小化交叉熵時將梯度傳播到每一層,
③ 除了輸出層使用Tanh(雙曲正切函數(shù))激活函數(shù),生成器所有層都使用ReLU(Rectified Linear Unit)激活函數(shù)。判別器所有層都使用Leaky ReLU(Leaky Rectified Linear Unit)激活函數(shù)。
結(jié)合SSGAN模型和DCGAN模型的特點,建立半監(jiān)督深度生成對抗網(wǎng)絡(luò)模型(SS-DCGAN)。SS-DCGAN引入深度卷積網(wǎng)絡(luò)作為生成器和分類器,與其他有監(jiān)督學習算法只使用有標簽樣本不同,SS-DCGAN訓練時還使用無標簽樣本。
圖3為無標簽樣本在分類中的作用。其中黑白色點代表有標簽樣本而灰色點代表無標簽樣本,虛線代表分類面。如果在訓練時只考慮有標簽樣本則會得到垂直的分類面,而考慮無標簽樣本之后分類器可以通過樣本整體分布得到更為準確的分類面。
圖3 無標簽樣本在分類中的作用
在進行訓練時,無標簽樣本雖然沒有類別信息,但是無標簽樣本的分布有助于學習數(shù)據(jù)的整體分布,幫助分類器更加準確分類。根據(jù)真實數(shù)據(jù)的整體分布,模型利用深度卷積網(wǎng)絡(luò)強大的特征提取能力生成擬合真實數(shù)據(jù)分布的偽樣本,這些樣本和真實數(shù)據(jù)共同輸入并訓練分類器,增加了訓練樣本個數(shù)。經(jīng)過足夠多樣本訓練后,分類器擁有很高的識別率。抽取訓練好的模型中分類器部分優(yōu)化調(diào)整后即得到用于圖像分類的網(wǎng)絡(luò)結(jié)構(gòu)。
SS-DCGAN引入7層和9層的深度卷積網(wǎng)絡(luò)作為生成器和分類器。
SS-DCGAN的生成器模型如圖4所示。生成器是一個7層的轉(zhuǎn)置卷積網(wǎng)絡(luò),包括2層全連接層和5層轉(zhuǎn)置卷積層。100+k維隨機噪聲z首先通過2層全連接層進行維度轉(zhuǎn)換,其中k為類別數(shù),然后經(jīng)過5層轉(zhuǎn)置卷積層進行反卷積轉(zhuǎn)換,最后輸出一個m×n×p的張量,即生成的圖像樣本。其中m×n代表圖像分辨率,p代表圖像通道數(shù)。
分類器的模型如圖5所示。分類器為一個9層的卷積網(wǎng)絡(luò),包括2層全連接層和7層卷積層。m×n×p的圖像樣本首先經(jīng)過7層卷積層進行特征提取,然后利用2層全連接層對卷積層提取的特征信息進行整合,最后輸出k+1維分類結(jié)果。前k維輸出為對應(yīng)類的置信度,第k+1維為判定為“偽”的置信度。
圖4 SS-DCGAN生成器結(jié)構(gòu)
圖5 SS-DCGAN分類器結(jié)構(gòu)
SS-DCGAN使用部分有標簽樣本和較多的無標簽樣本共同進行訓練,訓練的過程實質(zhì)上為生成器與分類器之間的博弈對抗。在生成器和分類器的博弈對抗中,分類器最小化實際輸出和期望輸出的交叉熵,而生成器通過判別器反饋最大化偽樣本判別為“真”的概率D(G(z)),此時生成器與判別器完成一次優(yōu)化。經(jīng)過多輪反饋和更新后,生成器和分類器達到納什均衡,此時生成的偽樣本擬合真實數(shù)據(jù)分布,分類器識別率最高。抽取訓練后模型的分類器部分經(jīng)過優(yōu)化調(diào)整即得到用于圖像識別的網(wǎng)絡(luò)結(jié)構(gòu)。
模型的損失函數(shù)如式(2)[9]所示,由兩部分組成,前兩項對應(yīng)真實數(shù)據(jù)的損失函數(shù),后半部分對應(yīng)生成樣本的損失函數(shù)。使用Adam(Adaptive Moment Estim-Ation)優(yōu)化器最小化損失函數(shù),對生成器和分類器進行優(yōu)化更新。
L=-Ex,y~Pdata(x,y)[logPmodel(y|x)]-
Ex,y~Pdata(x,y)[1-logPmodel(y=K+1|x)]-
Ex~G[logPmodel(y=K+1|x)]
(2)
將訓練之后得到的SS-DCGAN模型用于圖像識別。如圖6所示,將訓練好的SS-DCGAN中的分類器部分抽取出來,因為第k+1維輸出是判定為“偽”的置信度,所以忽略第k+1維中間輸出后將剩余部分輸出連接Softmax層,最終得到k維輸出,對應(yīng)輸入圖像在k個分類上的置信度。
圖6 SS-DCGAN圖像分類結(jié)構(gòu)
在MNIST和CIFAR-10兩個公開數(shù)據(jù)集上進行了圖像識別實驗。實驗環(huán)境為 Intel?CoreTMi7-7700k CPU@ 4.2 GHz處理器,16 GB運行內(nèi)存,Nvidia GeForce GTX 1080 GPU,TensorFlow框架。
MNIST數(shù)據(jù)集是目前應(yīng)用最廣泛的手寫字符數(shù)據(jù)集,包含70000張0~9的灰度圖像。其中60000張為訓練樣本,10000張為測試樣本。在使用圖像樣本進行訓練之前要進行歸一化處理將圖像數(shù)據(jù)限制在一定的范圍,這樣能在訓練時收斂得更快。而對于標簽數(shù)據(jù)進行獨熱編碼方便計算交叉熵。
訓練時,為了保持分類器和生成器的對抗平衡性,避免分類器過早達到最優(yōu)解使模型無法收斂,將分類器和生成器的更新頻率之比設(shè)置為1∶3。對于有標簽樣本數(shù)量設(shè)置4個不同的值,對比使用不同有標簽樣本數(shù)時圖像識別的準確率。為了保證準確性,通過隨機抽樣同時構(gòu)建10個樣本集進行實驗,實驗結(jié)果取平均值。
5.1.1 MNIST生成樣本
SS-DCGAN模型在訓練MNIST樣本時選擇Adam優(yōu)化器,學習率設(shè)置為0.0001,動量為0.5,批處理量為32。
圖7為SS-DCGAN在MNIST數(shù)據(jù)集上分類器損失函數(shù)隨訓練次數(shù)增加而變化的情況。
圖7 MNIST上d_loss變化趨勢
圖8為SS-DCGAN在MNIST數(shù)據(jù)上生成器損失函數(shù)隨訓練次數(shù)增加而變化的情況。
圖8 MNIST上g_loss變化趨勢
從圖7和圖8中可以看出生成器的損失函數(shù)總體上呈上升趨勢,而分類器的損失函數(shù)呈下降趨勢。在訓練初期兩者曲線較為平滑,隨著訓練次數(shù)的增加,生成器和分類器對抗引起損失函數(shù)曲線大幅振蕩。
圖9為生成樣本隨訓練迭代次數(shù)增加而產(chǎn)生的變化??梢钥吹皆谟柧毘跗谏傻臉颖緸槟:幕叶葓D像,不具備手寫數(shù)字特征。經(jīng)過15次迭代后圖像上有了較為明顯的特征,而在第25次迭代后得到了擬合真實數(shù)據(jù)分布的手寫數(shù)字圖像。
圖9 MNIST生成樣本
5.1.2 MNIST分類結(jié)果
表1為MNIST上SS-DCGAN模型使用20、50、100、200個有標簽樣本的識別率與其他方法使用60000個樣本進行訓練的識別率對比[8,12-14]。模型訓練共耗時4 h 27 min,圖像識別速率為13張/s。
表1 MNIST上各方法識別率對比
在使用200個有標簽樣本訓練時,本文方法的識別率已經(jīng)超過Linear Classifier[14]、KNN[12](K-NearestNeighbor)、ADGM[13](Auxiliary Deep Generative Model)、DCNN[8](Deep Convolutional Neural Networks),略低于C-DCGAN[8](Conditional Deep Convolutional Generative Adversarial Networks)。
C-DCGAN訓練時需要60000個樣本,與之相比,SS-DCGAN雖然識別率略低,但訓練時僅使用200個有標簽樣本,其余樣本為無標簽樣本即可達到與C-DCGAN同水平的識別率,可以節(jié)約大量用于樣本標注的人力與時間。實驗結(jié)果表明,SS-DCGAN解決了有標簽樣本數(shù)量較少時識別效果不佳的問題。
CIFAR-10數(shù)據(jù)集是由 Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集的10分類數(shù)據(jù)集,包含60000張32×32彩色圖像,其中50000張為訓練樣本,10000張為測試樣本。使用數(shù)據(jù)前需要進行歸一化處理將圖像數(shù)據(jù)等比例限制在一定范圍內(nèi)。標簽數(shù)據(jù)需要進行獨熱編碼。使用圖3結(jié)構(gòu)作為生成器,圖4作為分類器組成SS-DCGAN模型,考慮到CIFAR-10數(shù)據(jù)集樣本的特點,分類器和生成器更新頻率之比設(shè)置為1∶4。訓練時有標簽樣本數(shù)量設(shè)置4個不同的值,對比不同有標簽樣本數(shù)時圖像識別的準確率。為了保證準確性,通過隨機抽樣同時構(gòu)建10個樣本集進行實驗,實驗結(jié)果取平均值。
5.2.1 CIFAR-10生成樣本
SS-DCGAN模型在訓練CIFAR-10樣本時選擇Adam優(yōu)化器,學習率設(shè)置為0.0002,動量為0.5,批處理量為32。
圖10和圖11分別表示分類器和生成器損失函數(shù)隨訓練次數(shù)變化。
從圖10和圖11中可以看出生成器和分類器損失函數(shù)前期相對平滑,隨著訓練次數(shù)增加生成器損失函數(shù)呈上升趨勢而分類器損失函數(shù)呈下降趨勢并且兩者都表現(xiàn)出大幅振蕩。
圖10 CIFAR-10上d_loss變化趨勢
圖11 CIFAR-10上g_loss變化趨勢
圖12為生成樣本隨訓練迭代次數(shù)增加而產(chǎn)生的變化。在訓練初期生成的樣本為模糊的彩色條紋,在第100次迭代后圖像上有了較為明顯的特征,而在第300次迭代后得到了擬合真實數(shù)據(jù)分布的圖像樣本。
圖12 CIFAR-10生成樣本
5.2.2 CIFAR-10分類結(jié)果
表2為CIFAR-10上SS-DCGAN模型使用1000、2000、4000、8000個有標簽樣本訓練的識別率與1L K-means(1 Layer K-means)、3L K-means(3 Layer K-means)、Cudaconvnet、C-DCGAN((Conditional Deep Convolutional Generative Adversarial Networks))、VI K-means(View Invariant K-means)五種方法[8,15-16]使用50000個樣本進行訓練的識別率對比。模型訓練共耗時5 h 43 min,圖像識別速率為9張/s。
由表2可知,在使用4000個有標簽樣本訓練時,SS-DCGAN的識別率已經(jīng)超過其他圖像識別方法。實驗結(jié)果表明,該方法有效解決了有標簽樣本數(shù)量較少的情況時識別效果不佳的問題。
結(jié)合半監(jiān)督生成對抗網(wǎng)絡(luò)模型和深度卷積生成對抗網(wǎng)絡(luò)模型提出半監(jiān)督深度生成對抗網(wǎng)絡(luò)模型(SS-DCGAN)并抽取分類器部分用于圖像識別。SS-DCGAN模型使用有標簽樣本和無標簽樣本進行訓練,根據(jù)真實數(shù)據(jù)的整體分布,模型利用深度卷積網(wǎng)絡(luò)強大的特征提取能力生成擬合真實數(shù)據(jù)分布的偽樣本,這些偽樣本和真實數(shù)據(jù)共同輸入并訓練分類器,提升了識別率。將SS-DCGAN模型用于圖像識別,并在MNIST和CIFAR-10兩個公開數(shù)據(jù)集上進行了實驗。實驗結(jié)果表明,SS-DCGAN模型僅用少量有標簽樣本即達到了很高的識別率,有效解決了有標簽樣本數(shù)量較少的情況時識別效果不佳的問題。