張目飛,李 廷,蘇 鵬
(1 浪潮云信息技術(shù)股份公司 服務(wù)研發(fā)部,濟(jì)南 250000;2 山東浪潮新基建科技有限公司,濟(jì)南 250000)
隨著個人智能設(shè)備和圖像相關(guān)應(yīng)用的普及,會產(chǎn)生大量的圖像數(shù)據(jù),如何高效、合理地對這些圖像數(shù)據(jù)進(jìn)行合理的分類是一項(xiàng)技術(shù)難題。在過去的幾年中,深度神經(jīng)網(wǎng)絡(luò)(DNN)在計(jì)算機(jī)視覺和模式識別任務(wù)中,如:圖像分類、語義分割、對象檢測應(yīng)用廣泛。卷積神經(jīng)網(wǎng)絡(luò)中的卷積層能夠捕獲圖像的局部特征,以獲得與輸入維度相似的空間表示,使用全連接層和softmax 分類層生成概率表示,來達(dá)到分類效果[1]。He 等[2]提出了深度殘差網(wǎng)絡(luò)ResNet34,引入了殘差結(jié)構(gòu),可以更好地學(xué)習(xí)殘差信息,并在后續(xù)層中使用這些殘差信息,提高了圖像分類的性能,為深度學(xué)習(xí)領(lǐng)域帶來了新的思路和方法。
許多基于深度神經(jīng)網(wǎng)絡(luò),在網(wǎng)絡(luò)學(xué)習(xí)過程中添加注意力機(jī)制來獲得圖像中感興趣區(qū)域,通過選擇給定輸入的特征通道、區(qū)域來自動提取相關(guān)特征[3]。Woo 等[4]將注意力機(jī) 制模塊集成 到CNN中,提高網(wǎng)絡(luò)的特征表達(dá)能力,從而提高了圖像分類的準(zhǔn)確率;Wang[5]提出了殘差注意網(wǎng)絡(luò),殘差結(jié)構(gòu)可以使網(wǎng)絡(luò)更好地學(xué)習(xí)圖像中的特征,通過添加注意力模塊來學(xué)習(xí)圖像中的局部區(qū)域特征;Park 等[6]提出了一種新的注意力機(jī)制,可以在空間和通道維度上同時進(jìn)行特征加權(quán),更加準(zhǔn)確地捕捉到圖像中的重要信息;Xi 等[7]提出用殘差注意模塊進(jìn)行特征提取,以增強(qiáng)分類任務(wù)中的關(guān)鍵特征,抑制無用的特征;Liang[8]提出將自下而上和自上而下的前饋?zhàn)⒁饬埐钅K用于圖像分類。以上工作說明殘差結(jié)構(gòu)和注意力機(jī)制都可以幫助模型更好地學(xué)習(xí)圖像特征,提高圖像分類的準(zhǔn)確性。
隨著數(shù)據(jù)集規(guī)模的增大和類別的增多,訓(xùn)練一個高準(zhǔn)確率的分類模型變得越來越困難。傳統(tǒng)的數(shù)據(jù)增強(qiáng)方法對原始圖像進(jìn)行幾何變換或者對圖像進(jìn)行隨機(jī)擾動,雖然可以增加數(shù)據(jù)集的樣本量,提高分類模型的準(zhǔn)確率,但是這些方法無法生成新的數(shù)據(jù)分布。而生成網(wǎng)絡(luò)是一種可以學(xué)習(xí)數(shù)據(jù)分布的生成模型,可以生成新的樣本,從而擴(kuò)大數(shù)據(jù)集并且增加數(shù)據(jù)多樣性,從而可以提高分類模型的泛化性[9]。因此,本文提出一個深度殘差注意力生成網(wǎng)絡(luò)來生成圖像數(shù)據(jù),對數(shù)據(jù)進(jìn)行必要的數(shù)據(jù)增強(qiáng),利用ResNet34 網(wǎng)絡(luò)進(jìn)行圖像分類。
本文提出了一個深度殘差注意力生成網(wǎng)絡(luò)模型用于圖像數(shù)據(jù)增強(qiáng),主要結(jié)構(gòu)包括生成器、判別器和殘差注意力模塊。生成器包含4 個反卷積層(DConv)和3 個殘差注意力模型(SPAM),殘差注意力模型能夠?qū)D像的重點(diǎn)區(qū)域進(jìn)行特別關(guān)注,以生成高質(zhì)量的圖像,在生成器的最后一層使用Tanh 函數(shù)將數(shù)據(jù)映射到[-1,1]的區(qū)間內(nèi);判別器包括4 個卷積層(Conv),能夠提取圖像細(xì)節(jié)特征。深度殘差注意力生成網(wǎng)絡(luò)模型結(jié)構(gòu)如圖1 所示。
圖1 深度殘差注意力生成網(wǎng)絡(luò)模型結(jié)構(gòu)Fig.1 Deep residual attention generation network model
生成網(wǎng)絡(luò)由生成器和判別器組成。生成器將隨機(jī)向量Z作為輸入,學(xué)習(xí)真實(shí)數(shù)據(jù)分布p(x)從而合成逼真的圖像;判別器區(qū)分生成的圖像與真實(shí)的圖像,其輸出表示從真實(shí)分布p(x)提取樣本y的概率。生成網(wǎng)絡(luò)的最終目標(biāo)是讓生成器生成和真實(shí)圖像相同的數(shù)據(jù)分布,而判別器無法判定圖像為真實(shí)圖像還是生成圖像,達(dá)到一個納什平衡。在生成器和判別器相互博弈的過程中,生成網(wǎng)絡(luò)的目標(biāo)函數(shù)定義為公式(1):
其中,p(x)表示真實(shí)數(shù)據(jù)分布;p(z)表示生成數(shù)據(jù)分布;D(x)表示判別器運(yùn)算;G(z)表示生成器運(yùn)算。
本文隨機(jī)選取Z=100 維的隨機(jī)數(shù)據(jù)作為生成器的輸入,經(jīng)過生成器生成圖像;判別器網(wǎng)絡(luò)的輸入為生成圖像和真實(shí)圖像,判別器網(wǎng)絡(luò)指導(dǎo)生成器合成圖像,鼓勵生成器捕捉更為精細(xì)的特征細(xì)節(jié),使得生成器生成的圖像和真實(shí)圖像難以區(qū)分。
殘差注意力模型使具有相似特征的區(qū)域相互增強(qiáng),以突出全局視野中的感興趣區(qū)域,殘差注意力模型如圖2 所示。通過sigmoid 函數(shù)可以得到一個[0,1]的系數(shù),給每個通道或空間分配不同的權(quán)重,可以給每個特征圖分配不同的重要程度。
圖2 殘差注意力模型Fig.2 Residual attention model
本文設(shè)C × H × W為殘差注意力模型的輸入,C為特征圖的數(shù)量,H和W分別表示為圖像的高度和寬度;通過卷積和批量歸一化運(yùn)算對輸入的特征進(jìn)行處理,利用Sigmoid函數(shù)得到空間注意系數(shù)S;將輸入的特征圖和通過注意力模型得到的特征圖利用殘差結(jié)構(gòu)進(jìn)行融合,得到最終的殘差空間注意力特征表示,公式(2)和公式(3):
其中,X表示空間注意模型的輸入,Conv 表示卷積運(yùn)算。
首先,對輸入圖像進(jìn)行數(shù)據(jù)預(yù)處理,主要包括:將圖像裁剪為28×28 的大小,并進(jìn)行隨機(jī)旋轉(zhuǎn)和對比度增強(qiáng);其次,將預(yù)處理的數(shù)據(jù)送入到深度殘差注意力生成網(wǎng)絡(luò)中進(jìn)行數(shù)據(jù)增強(qiáng)。深度殘差注意力生成網(wǎng)絡(luò)通過學(xué)習(xí)圖像不變性特征,合成高質(zhì)量的數(shù)據(jù),注意力機(jī)制對圖像的感興趣區(qū)域進(jìn)行重點(diǎn)關(guān)注;生成器通過學(xué)習(xí)隨機(jī)數(shù)據(jù)來生成感興趣的圖像分布,判別器學(xué)習(xí)真實(shí)樣本的分布,辨別生成器生成的圖像;同時訓(xùn)練生成器和判別器,促使兩者競爭,在理想情況下,生成器可以生成近似于真實(shí)的圖像數(shù)據(jù),而判別器不能將真實(shí)圖像與生成圖像區(qū)分,從而達(dá)到納什平衡,達(dá)到數(shù)據(jù)增強(qiáng)的目的;最后,利用ResNet34 網(wǎng)絡(luò)對增強(qiáng)的圖像數(shù)據(jù)進(jìn)行分類。
本文使用PyTorch 深度學(xué)習(xí)框架來訓(xùn)練模型,GPU 為NVIDIA Tesla V100,顯存為32 GB。采用Adam 算法優(yōu)化損失函數(shù),采用小批量樣本的方式訓(xùn)練深度學(xué)習(xí)模型,batch_size 設(shè)置為64,在訓(xùn)練的過程中采用固定步長策略調(diào)整學(xué)習(xí)率,初始學(xué)習(xí)率設(shè)置為0.000 1,gamma 值為0.85,L2 正則化系數(shù)設(shè)置為0.000 1,迭代次數(shù)為50 000 次。
本文采用的數(shù)據(jù)集為MNIST 數(shù)據(jù)集和cirfar10數(shù)據(jù)集。MNIST 數(shù)據(jù)集一共有70 000張圖片,其中60 000 張作為訓(xùn)練集,10 000 張作為測試集,每張圖片由28×28 的0~9 的手寫數(shù)字圖片組成;cirfar10數(shù)據(jù)集由60 000 張32×32 的彩色圖片組成,一共有十個類別,每個類別有6 000 張圖片,其中50 000 張圖片作為訓(xùn)練集,10 000 張圖片作為測試集。
使用深度殘差注意力生成網(wǎng)絡(luò)分別對MNIST和cirfar10 數(shù)據(jù)集中的圖像進(jìn)行圖像增強(qiáng),使得圖像的特征更加多樣,對MNIST 數(shù)據(jù)集進(jìn)行數(shù)據(jù)增強(qiáng)的效果如圖3 所示,對cirfar10 數(shù)據(jù)進(jìn)行數(shù)據(jù)增強(qiáng)的效果如圖4 所示。
圖3 MNIST 數(shù)據(jù)集數(shù)據(jù)增強(qiáng)的效果Fig.3 Effect of data enhancement of MNIST dataset
圖4 cirfar10 數(shù)據(jù)集數(shù)據(jù)增強(qiáng)的效果Fig.4 Effect of data enhancement on the cirfar10 dataset
從圖3 和圖4 可以看出,使用深度殘差注意力生成網(wǎng)絡(luò)對MNIST 和cirfar10 數(shù)據(jù)集進(jìn)行數(shù)據(jù)增強(qiáng),具有很強(qiáng)的視覺可讀性,同時也具有較清晰的紋理特征,實(shí)現(xiàn)了數(shù)據(jù)增強(qiáng),擴(kuò)充了數(shù)據(jù)集。
為了驗(yàn)證本文模型數(shù)據(jù)增強(qiáng)后的MNIST 以及cirfar10 數(shù)據(jù)在分類方面的效果,選擇 CNN、ResNet18、ResNet34、ResNet50 和ResNet101 作為分類網(wǎng)絡(luò)做對比實(shí)驗(yàn)。第一組測試增強(qiáng)數(shù)據(jù)的分類準(zhǔn)確率;第二組,測試原始數(shù)據(jù)的分類準(zhǔn)確率;第三組,將增強(qiáng)數(shù)據(jù)和原始數(shù)據(jù)各拿出50%組成新的數(shù)據(jù)集進(jìn)行測試,實(shí)驗(yàn)結(jié)果見表1 和表2。
表1 MNIST 數(shù)據(jù)集分類準(zhǔn)確率實(shí)驗(yàn)結(jié)果(%)Tab.1 Experimental results of classification accuracy of MNIST dataset(%)
通過表1 和表2 可以看出,使用深度殘差注意力生成網(wǎng)絡(luò)進(jìn)行數(shù)據(jù)增強(qiáng)能夠提高數(shù)據(jù)集的分類效果,證明本文提出的模型是切實(shí)有效的。利用本文模型進(jìn)行數(shù)據(jù)增強(qiáng)的數(shù)據(jù)和原始數(shù)據(jù)相結(jié)合,在MNIST 數(shù)據(jù)集上達(dá)到了98.95% 的準(zhǔn)確率,在cirfar10 數(shù)據(jù)集上達(dá)到了92.68%的準(zhǔn)確率。
表2 cirfar10 數(shù)據(jù)集分類準(zhǔn)確率實(shí)驗(yàn)結(jié)果(%)Tab.2 Experimental results of classification accuracy(%)for the cirfar10 dataset
本文提出了一種深度殘差注意力生成網(wǎng)絡(luò)用于數(shù)據(jù)增強(qiáng),從而提高分類的準(zhǔn)確率。實(shí)驗(yàn)結(jié)果證明,該模型在MNIST 數(shù)據(jù)集上獲得了98.95%的準(zhǔn)確率,準(zhǔn)確率提高了0.93 個百分點(diǎn);在cirfar10 數(shù)據(jù)集上獲得了92.68%的準(zhǔn)確率,準(zhǔn)確率提高了0.65 個百分點(diǎn)。本文模型的提出,為數(shù)據(jù)增強(qiáng)提供了一種解決思路和方式。