馬艷龍, 馬宏斌, 王英麗
(黑龍江大學(xué) 電子工程學(xué)院, 哈爾濱 150080)
近年來,圖像識別與分類技術(shù)在計算機(jī)視覺、模式識別、機(jī)器學(xué)習(xí)和深度學(xué)習(xí)等研究領(lǐng)域都是熱點(diǎn)問題[1]。目前圖像識別主要有兩種方法,基于統(tǒng)計算法的圖像識別方法和深度學(xué)習(xí)的圖像識別方法?;诮y(tǒng)計的圖像識別主要有最近鄰距離等方法,此方法需要對圖像進(jìn)行較為復(fù)雜的預(yù)處理,并且在一定程度上忽略了圖像的語義信息,存在圖像識別準(zhǔn)確度較低的問題[2]?;谏疃葘W(xué)習(xí)的圖像識別主要利用卷積神經(jīng)網(wǎng)絡(luò)、循環(huán)神經(jīng)網(wǎng)絡(luò)、膠囊網(wǎng)絡(luò)和深度置信網(wǎng)絡(luò)提取圖片的高層語義特征信息并對圖像進(jìn)行分類識別[3-5]。
隨著圖像識別技術(shù)的飛速發(fā)展,基于深度學(xué)習(xí)的圖像識別方法研究越來越多。Zhong等利用動態(tài)貝葉斯網(wǎng)絡(luò)(Dynamic Bayesian network,DBN)對高光譜遙感圖像進(jìn)行識別,但是DBN的參數(shù)較多,并且對圖像高層次的特征提取能力較弱,在模型訓(xùn)練上存在一定程度的過擬合現(xiàn)象,造成分類精度較低[6]。Mou等改進(jìn)了循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent neural network, RNN),能有效地對輸入的序列數(shù)據(jù)進(jìn)行分析,并進(jìn)行識別,容易存在梯度消失或者梯度爆炸的現(xiàn)象[7]。Deng等提出改進(jìn)的膠囊網(wǎng)絡(luò)對高光譜成像(Hyperspectral image, HSI)進(jìn)行識別,雖然網(wǎng)絡(luò)結(jié)構(gòu)較簡單,但與其他深度學(xué)習(xí)方法相比準(zhǔn)確率仍然較低[8]。卷積神經(jīng)網(wǎng)絡(luò)廣泛地應(yīng)用于圖像識別,并具有較高的識別準(zhǔn)確率,容易出現(xiàn)梯度消失的問題[9]。
2014年,Goodfellow等提出生成式對抗網(wǎng)絡(luò),首次研究了生成式深度學(xué)習(xí)模型[10]。生成對抗網(wǎng)絡(luò)是一種非常優(yōu)秀的生成模型,其創(chuàng)新之處在于將對抗博弈的思想引入其中,通過反向傳播算法實(shí)現(xiàn)端到端的訓(xùn)練,極大地促進(jìn)了生成模型的發(fā)展。GAN可以分為生成器G和判別器D兩部分。生成器主要是學(xué)習(xí)真實(shí)數(shù)據(jù)的分布,生成和真實(shí)數(shù)據(jù)近乎相似的生成樣本,達(dá)到欺騙判別器的目的。判別器盡可能地將真實(shí)數(shù)據(jù)與生成數(shù)據(jù)區(qū)分開,達(dá)到明辨真假的目的。二者在博弈對抗的過程中不斷優(yōu)化自己,最終達(dá)到納什均衡狀態(tài)。該模型在機(jī)器視覺領(lǐng)域得到了較為廣泛的應(yīng)用,通過生成較為逼真的圖像作為數(shù)據(jù)集的擴(kuò)充。目前,GAN主要用于生成超分辨率的圖像[11]、實(shí)現(xiàn)圖像間的翻譯[12]、語音增強(qiáng)[13]、文本生成[14]和由文本生成圖像[15]等領(lǐng)域,在圖像識別領(lǐng)域的應(yīng)用卻不多。
基于上述問題,本文進(jìn)行基于生成對抗網(wǎng)絡(luò)圖像識別方法的研究。在WGAN的基礎(chǔ)上添加條件使其擴(kuò)展成為條件模型,并且用卷積神經(jīng)網(wǎng)絡(luò)代替WGAN中的生成器和判別器,提出了條件生成對抗網(wǎng)絡(luò)CD-WGAN模型并用于圖像識別。實(shí)驗(yàn)結(jié)果表明,該模型具有更高的識別準(zhǔn)確率。
圖1 GAN模型結(jié)構(gòu)圖Fig.1 Structure diagram of GAN model
GAN的模型結(jié)構(gòu)如圖1所示,隨機(jī)噪聲z通過生成器G生成和真實(shí)數(shù)據(jù)相似度極高的偽數(shù)據(jù),判別器D則判斷輸入的數(shù)據(jù)是真實(shí)的數(shù)據(jù)還是生成器生成的數(shù)據(jù)。
GAN模型訓(xùn)練采用交替訓(xùn)練的方式。在訓(xùn)練判別器D時,固定生成器G,訓(xùn)練判別器是希望使目標(biāo)函數(shù)最小化。而在訓(xùn)練生成器G時,則固定判別器D,隨機(jī)噪聲z服從先驗(yàn)分布Pz并通過生成器生成偽數(shù)據(jù),訓(xùn)練生成器是希望使目標(biāo)函數(shù)最大化。式(1)為GAN的目標(biāo)函數(shù),其中D(x)表示判別器D判斷真實(shí)數(shù)據(jù)為真的概率,D(G(z))表示判別器D判斷生成數(shù)據(jù)為真的概率。對于判別器而言,人們希望正確區(qū)分輸入數(shù)據(jù)的真假,所以它希望來自真實(shí)數(shù)據(jù)的判別值D(x)越大越好,來自生成數(shù)據(jù)的判別值D(G(z))越小越好。對于生成器而言,它希望生成的數(shù)據(jù)更加真實(shí)以達(dá)到騙過判別器的目的,所以它希望判別值D(G(z))越大越好,這樣生成的數(shù)據(jù)更接近于真實(shí)數(shù)據(jù)的分布。生成器和判別器網(wǎng)絡(luò)通過不斷地對抗訓(xùn)練,最終達(dá)到納什均衡,生成器生成的圖片越來越趨近于真實(shí)圖片。判別器無法判斷出輸入數(shù)據(jù)的真假,使得判別器的準(zhǔn)確率穩(wěn)定在1/2上,也就是判別器只能對訓(xùn)練樣本進(jìn)行0或1的隨機(jī)猜測。
V(G,D) =Ex~Pdata[logD(x)]+Ez~Pz[log(1-D(G(z)))]
(1)
GAN自從被提出以后,學(xué)術(shù)界以及各大互聯(lián)網(wǎng)巨頭開始致力于GAN的深入研究,GAN一度成為深度學(xué)習(xí)領(lǐng)域最熱門的研究對象。與傳統(tǒng)的生成模型不同,GAN最大的優(yōu)勢在于借鑒了博弈論思想,在模型中引入了對抗機(jī)制,通過對抗訓(xùn)練的方式來擬合真實(shí)數(shù)據(jù)的分布,生成和真實(shí)數(shù)據(jù)近乎一致的圖像,最終達(dá)到難以區(qū)分真假的目的。但是原始的GAN會存在許多問題,如在訓(xùn)練GAN時,納什均衡的狀態(tài)只是理論上可以實(shí)現(xiàn),但是在實(shí)際訓(xùn)練時,GAN是十分不穩(wěn)定的,在訓(xùn)練時較難收斂。所以,WGAN模型代替?zhèn)鹘y(tǒng)的GAN模型可以有效地解決GAN訓(xùn)練不穩(wěn)定的問題。
為了使GAN在訓(xùn)練時較容易收斂,Martin等使用Wasserstein距離來衡量真實(shí)分布與生成分布之間的差距[13],取代了原始GAN使用JS散度來表示生成分布與真實(shí)分布之間的距離的方式,理論上可以解決原始GAN存在的梯度消失問題,該距離具體的定義為:
(2)
將Wasserstein距離加入到GAN中,并且通過形式變換,得到:
(3)
式中:在fw(x)的Lipschitz函數(shù)‖fw‖L≤K時,滿足條件的fw(x)均能取到Ex~Pr,[fw(x)]-Ex~Pg[Fw(x)]的上界,最終與K相除,可求出近似解如式(4)所示:
(4)
構(gòu)建判別器網(wǎng)絡(luò)fw,并且規(guī)定判別器網(wǎng)絡(luò)的參數(shù)w有界,則可使得[fw(x)]-Ex~Pg[Fw(x)]盡可能取到最大值,此時[fw(x)]-Ex~Pg[Fw(x)]就近似等于生成圖片分布與真實(shí)圖片分布之間的Wasserstein距離。與此同時,生成器網(wǎng)絡(luò)也要使Wasserstein距離最小化,即讓[fw(x)]-Ex~Pg[Fw(x)]最小。因此,GAN判別器的損失函數(shù)就變成[fw(x)]-Ex~Pg[Fw(x)],而生成器的損失函數(shù)變成-Ex~Pg[fw(x)]。這樣生成器和判別器的訓(xùn)練就達(dá)到全局最優(yōu),并解決了原始GAN在對抗訓(xùn)練中較難收斂的問題。本文在WGAN的基礎(chǔ)上,在生成器中加入限制條件,構(gòu)成條件生成對抗網(wǎng)絡(luò)。
將條件c與隨機(jī)噪聲z一起輸入到生成器中,這樣就把原始的GAN模型擴(kuò)展成CGAN條件模型[14]。如果加入的條件c代表類別,則可以通過控制輸入的條件c來控制生成特定類別的圖片。
圖2 CGAN模型結(jié)構(gòu)圖
CGAN的模型結(jié)構(gòu)如圖2所示,真實(shí)數(shù)據(jù)x和條件c同時輸入到生成器中得到生成數(shù)據(jù)。判別器的輸入除了包含真實(shí)數(shù)據(jù)和生成數(shù)據(jù),還加入了條件c,判別器除了要區(qū)分生成器生成的數(shù)據(jù)以及真實(shí)數(shù)據(jù),還需判別輸入的數(shù)據(jù)是否和條件c相匹配。因此,可得CGAN的目標(biāo)函數(shù)如式(5)所示。式中D(x|c)表示真實(shí)圖片與條件c輸入到判別器中,輸出結(jié)果為真的概率,G(x|c)表示隨機(jī)噪聲與條件c輸入到生成器中得到的生成樣本,D(G(x|c))表示判別器對G(x|c)的判別結(jié)果為真的概率。
V(G,D) =Ex~Pdata[logD(x|c)]+Ez~Pz[log(1-D(G(z|c)))]
(5)
在WGAN的基礎(chǔ)上,其用卷積神經(jīng)網(wǎng)絡(luò)代替生成器和判別器,并在其生成器中加入限制條件,最終構(gòu)成了有條件的生成對抗網(wǎng)絡(luò)CD-WGAN模型。在訓(xùn)練的過程中,生成器和判別器不斷對抗,經(jīng)過多次迭代以后,模型達(dá)到收斂狀態(tài)。最后,將訓(xùn)練好的CD-WGAN模型的判別器提取出來并進(jìn)行微調(diào),在其最后一層加入Softmax后形成新的網(wǎng)絡(luò),用于圖像識別。
圖3 CD-WGAN模型生成器結(jié)構(gòu)圖
CD-WGAN的生成器模型如圖3所示,生成器的輸入維度分別為40和2的隨機(jī)噪聲與標(biāo)簽,通過Concat操作,變?yōu)?2維的向量輸入到生成器中,經(jīng)過全連接層,然后通過維度轉(zhuǎn)換變成(6,6,64)的張量,經(jīng)過2個核大小為2×2、步長為2的轉(zhuǎn)置卷積和4個核大小為2×2、步長為1的轉(zhuǎn)置卷積,得到(24,24,32)的張量。(24,24,32)的張量與(1,1,2)的張量Concat得到大小為(24,24,34)的張量,最后經(jīng)過一層Sigmod層,得到(24,24,1)的張量,即為生成的圖片。
圖4 CD-WGAN模型判別器結(jié)構(gòu)圖
CD-WGAN的判別器模型如圖4所示,判別器的輸入維度大小為(24,24,1)的圖片,經(jīng)過7層核大小為2×2、步長為1的卷積層后,得到大小為(6,6,64)的張量,通過Reshape操作并且通過一層全連接層,輸出512維的向量,最后通過一層全連接層輸出判別結(jié)果。
CD-WGAN的圖像識別模型如圖5所示。在訓(xùn)練過程中,生成器不斷學(xué)習(xí)輸入圖像的特征,訓(xùn)練結(jié)束以后只需要對訓(xùn)練好的判別器進(jìn)行微調(diào)即可用于圖像識別。在模型訓(xùn)練的過程中判別器的輸出為一維,表示判別器判斷輸入樣本是真實(shí)樣本還是生成樣本,這個結(jié)果對于圖像識別沒有任何意義。所以需要對判別器進(jìn)行微調(diào),才能進(jìn)一步用于圖像識別。將判別器的最后一層替換成維數(shù)為n的全連接層(n為圖像識別的總類別數(shù)),然后添加Softmax層,經(jīng)過Softmax層輸出每個標(biāo)簽值的概率。網(wǎng)絡(luò)結(jié)構(gòu)調(diào)整以后,將原訓(xùn)練樣本集輸入分類器網(wǎng)絡(luò),使用Adam算法優(yōu)化最小化損失函數(shù),進(jìn)一步得到可用于圖像識別的模型。
本文的硬件平臺采用32 GB運(yùn)行內(nèi)存,Inter(R) Core(TM) i7-7700k CPU@4.2 GHZ處理器,NVIDIA GeForce GTX 1080GPU。本文的軟件平臺基于Windows平臺搭建的深度學(xué)習(xí)框架Pytorch,使用Pytorch框架訓(xùn)練模型進(jìn)行實(shí)驗(yàn)對比。分別在MNIST、CIFAR-10和CelebA數(shù)據(jù)集上進(jìn)行了試驗(yàn),并與其他的主流圖像識別方法對比,驗(yàn)證所提出模型的有效性。
MNIST數(shù)據(jù)集來自美國國家標(biāo)準(zhǔn)與技術(shù)研究所(MNIST),數(shù)據(jù)集包含70 000張黑白圖片,其中有60 000張訓(xùn)練圖片和10 000張測試圖片,均為手寫數(shù)字圖片,每一張圖片都表示了0~9的任意一個數(shù)字。為了降低圖像噪聲且優(yōu)化訓(xùn)練模型的時間,訓(xùn)練前進(jìn)行歸一化處理。Loss使用交叉熵?fù)p失函數(shù)。為了避免梯度爆炸和梯度消失,生成器的最后一層激活函數(shù)為Sigmoid函數(shù),其他層的激活函數(shù)為Relu,判別器的激活函數(shù)為LeakyRelu函數(shù)。在進(jìn)行實(shí)驗(yàn)時,判別器的學(xué)習(xí)率設(shè)置為0.000 4,生成器的學(xué)習(xí)率為0.000 1,優(yōu)化器為Adam。生成器和判別器的Loss曲線分別如圖6和圖7所示。可以看出,在訓(xùn)練初期,損失函數(shù)曲線較為平滑,隨著迭代次數(shù)的增加,生成器和判別器由于對抗引起損失函數(shù)的大幅振蕩。
圖6 MNIST上生成器Loss曲線
CIFAR-10數(shù)據(jù)集包含60 000張彩色圖片,一共10類。CIFAR-10的數(shù)據(jù)的每類圖片有6 000張,其中5 000張為訓(xùn)練集圖片,1 000張為測試集圖片。為了降低圖像噪聲且優(yōu)化訓(xùn)練模型的時間,訓(xùn)練前進(jìn)行歸一化處理。Loss使用交叉熵?fù)p失函數(shù)。為了避免梯度爆炸和梯度消失,生成器的最后一層激活函數(shù)為Sigmoid函數(shù),其他層的激活函數(shù)為Relu,判別器的激活函數(shù)為LeakyRelu函數(shù)。在進(jìn)行實(shí)驗(yàn)時,判別器的學(xué)習(xí)率設(shè)置為0.000 4,生成器的學(xué)習(xí)率為0.000 1,優(yōu)化器為Adam。生成器和判別器的Loss曲線分別如圖8和圖9所示??梢钥闯觯谟?xùn)練初期,損失函數(shù)曲線較為平滑,隨著迭代次數(shù)的增加,生成器和判別器由于對抗引起損失函數(shù)的大幅振蕩。與MNIST數(shù)據(jù)集實(shí)驗(yàn)相比,由于CIFAR-10數(shù)據(jù)集相對復(fù)雜,所以表現(xiàn)效果略差。
圖8 MNIST上生成器Loss曲線
CelebA為名人人臉數(shù)據(jù)集,包含了10 177個名人的202 599張人臉的圖片。為了降低圖像噪聲且優(yōu)化訓(xùn)練模型的時間,訓(xùn)練前進(jìn)行歸一化處理。Loss使用交叉熵?fù)p失函數(shù)。為了避免梯度爆炸和梯度消失,生成器的最后一層激活函數(shù)為Sigmoid函數(shù),其他層的激活函數(shù)為Relu,判別器的激活函數(shù)為LeakyRelu函數(shù)。判別器的學(xué)習(xí)率設(shè)置為0.000 8,生成器的學(xué)習(xí)率為0.000 1,優(yōu)化器為Adam。生成器和判別器的Loss曲線分別如圖10和圖11所示。可以看出,在訓(xùn)練初期,損失函數(shù)曲線較為平滑,隨著迭代次數(shù)的增加,生成器和判別器由于對抗引起損失函數(shù)的大幅振蕩。但是相對于MNIST和CIFAR-10數(shù)據(jù)集,CelebA人臉數(shù)據(jù)集的表現(xiàn)結(jié)果略差一些,主要原因在于人臉圖像比較復(fù)雜,并且受很多因素影響,所以進(jìn)行模型訓(xùn)練和圖像識別也較為困難。
圖10 MNIST上生成器Loss曲線
本實(shí)驗(yàn)采用了MNIST、CIFAR-10和CelebA三個不同的數(shù)據(jù)集,MNIST數(shù)據(jù)集里均為黑白數(shù)字圖片,而CIFAR-10和celebA為彩色圖片。MNIST和CIFAR-10為數(shù)字和物體的圖片,而CelebA為人臉圖片。采用這三個風(fēng)格和類型較大的數(shù)據(jù)集作為本文實(shí)驗(yàn)的數(shù)據(jù)集,能夠驗(yàn)證本文提出的模型的有效性。各個模型在不同數(shù)據(jù)集下的結(jié)果如表1所示。
表1 不同模型在各個數(shù)據(jù)集下的結(jié)果
本文基于生成對抗網(wǎng)絡(luò)的對抗博弈思想,將WGAN擴(kuò)展為條件模型,并將其生成器和判別器都替換成卷積神經(jīng)網(wǎng)絡(luò),提出了基于CD-WGAN的圖像識別方法。 此圖像識別方法在三個不同的數(shù)據(jù)集上驗(yàn)證了實(shí)驗(yàn)的有效性,與基于Linear classifier、SVM和CNN的圖像識別方法相比,CD-WGAN模型具有更高的圖片識別準(zhǔn)確率。這較好地解決了主流的圖像識別方法存在的識別準(zhǔn)確率較低的問題,具有一定的實(shí)用價值。