趙 莉,白猛猛,趙亞欣,肖 鋒
(西安工業(yè)大學 計算機科學與工程學院,西安 710021)
圖像域轉換是計算機視覺里一個重要的新研究方向,具有廣闊的應用前景,如衛(wèi)星遙感影像圖像轉換電子地圖圖像,殘缺圖像轉換補全后完整圖像等領域.以往的圖像域轉換基于人工設計的方式進行,包含顏色空間域轉換,小波變換等操作.這些域轉換算法需進行大量人工優(yōu)化參數(shù)才能滿足特定任務的需要.現(xiàn)今隨著深度學習的發(fā)展,生成式對抗網(wǎng)絡(Generative Adversarial Networks,GAN)被提出用來學習特定目標數(shù)據(jù)類型的數(shù)據(jù)分布規(guī)律,從而生成新的目標數(shù)據(jù)類型中的數(shù)據(jù).
GAN是深度學習發(fā)展的一個新的分支,是基于博弈論的一種深度學習網(wǎng)絡.GAN包含兩類網(wǎng)絡一種是含有反卷積層的生成式網(wǎng)絡用于生產(chǎn)數(shù)據(jù),另一類為起鑒別作用的鑒別器網(wǎng)絡,同生成網(wǎng)絡產(chǎn)生博弈.2014年文獻[1]首次提出GAN,此時的GAN通過對其輸入一個隨機噪音信號,可隨機生成不同目標類型數(shù)據(jù)的樣本,此時的GAN只用于學習目標數(shù)據(jù)類型的數(shù)據(jù)分布.在原始的GAN上通過改進鑒別網(wǎng)絡[2]產(chǎn)生的Conditional GAN能夠根據(jù)輸入帶有限制條件的信號生成特定的目標類型數(shù)據(jù),如文獻[3]可根據(jù)輸入的條件不一樣產(chǎn)生不同角度的帶有窗戶的室內(nèi)圖像,文獻[4]可根據(jù)輸入的描述性文字經(jīng)處理后輸入GAN產(chǎn)生與文字描述相對應的圖像等.Conditional GAN的產(chǎn)生為圖像域的轉換成帶來了新的解決思路,即將原始域的圖像作為輸入GAN的條件信號,目標域的圖像作為GAN產(chǎn)生的對象.基于該思路的圖像域轉換,如文獻[5]利用文獻[6]設計的生成網(wǎng)絡結合真假鑒別網(wǎng)絡,對輸入的人臉圖像實現(xiàn)了卡通風格的轉換,并且為無監(jiān)督的訓練方式;文獻[7]利用U-net[8]以及配對鑒別網(wǎng)絡實現(xiàn)了多種應用的圖像域轉換,然而其訓練方式變成了有監(jiān)督的訓練方式;文獻[9]基于原始圖像轉換到目標域圖像再轉換回原始圖像的訓練思路,在多種評價指標略差于pix2pix網(wǎng)絡的情況下實現(xiàn)了網(wǎng)絡的無監(jiān)督訓練.
圖像域轉換問題是將一副圖像從一個域轉換到另一個域,對于該問題可利用GAN先對輸入的原始域圖像進行處理,再根據(jù)目標域圖像產(chǎn)生的內(nèi)部規(guī)律生成與之對應的圖像.為了使圖像域轉換后的結果更加真實,深入分析以上文獻的優(yōu)缺點,文中擬針對圖像域轉換特點提出新的生成網(wǎng)絡、鑒別網(wǎng)絡和損失函數(shù),使生成網(wǎng)絡能夠更好的學習到目標域圖像數(shù)據(jù)的分布,提高圖像域轉換的準確率.
GAN網(wǎng)絡由生成網(wǎng)絡和鑒別網(wǎng)絡兩類網(wǎng)絡組成[10-11],文中的設計的網(wǎng)絡亦遵循該基本原則,如圖1所示.Input source為輸入的原始域圖像,Output為生成的目標域圖像,Target為真實的目標域圖像.F-net和G-net是U-net中的兩部分,前者包含大量的卷積操作用于提取輸入圖像的高維特征,后者則含有大量的反卷積操作用于生成圖像.Pair-net用于判斷輸入的原始域圖像與生成的目標域圖像是否一致,并且同G-net產(chǎn)生博弈產(chǎn)生Pair-loss損失用于更新G-net.圖1中出現(xiàn)的2處F-net為共享權重相同結構的網(wǎng)絡,由于原始域和目標域圖像表現(xiàn)形式不同但擁有本質(zhì)的共同特征,對原始域圖像以及生成的目標域圖像輸入F-net產(chǎn)生的特征圖,利用多層特征圖之間的F-loss損失可用于更新G-net確保其生成的目標域圖像是對應原始域圖像.
圖1 網(wǎng)絡結構圖
文中設計的GAN的整體損失函數(shù)如下:
(1)
式中:LGAN為網(wǎng)絡整體損失函數(shù);α為設置的參數(shù),在最小化LG生成網(wǎng)絡損失函數(shù)時需要最大化配對鑒別網(wǎng)絡LPair的損失,即使生成的圖像盡量騙過配對鑒別網(wǎng)絡,又由于LG中包含LPair故在訓練的過程中需不斷的訓練Pair-net使其鑒別原始域圖像與生成的目標域圖像的能力等到提升,使得兩個網(wǎng)絡交替更新權重相互博弈最終使得彼此的能力都得到提升.式1中LPair損失部分如下:
LPair=Ex,y∈Pdata(x,y)[logPair(x,y)]+
Ex∈Pdata(x,y)[log(1-Pair(x,G(x)))]
(2)
式中:Pdata(x,y)為配對的數(shù)據(jù)集;x為原始域圖像;y為真實的目標域圖像;Pair(x,y)輸入x原始域圖像和其配對的真實目標域圖像y輸出是否配對的預測值;G(x)為輸入x原始域圖像輸出生成的目標域圖像.該部分損失函數(shù)基于交叉熵損失函數(shù)改寫,輸出的值越大表示輸入的目標域圖像越是真實.式(1)中LG損失函數(shù)部分具體如下:
LG=Lpair+aLf1+bLf2+cLf3+dL1
(3)
式中:Lpair為式2中的損失函數(shù);a,b,c,d為參數(shù),L1為真實目標域圖像與G生成的目標域圖像的正則項,公式為
L1=Ey∈Pdata(x,y)(‖y-G(x))‖1)
(4)
對于Lf1、Lf2和Lf3其損失的計算同Lf*一致,如下:
Lf*=Ex∈Pdata(x,y)(d(f*(x)
f*(G(x))))
(5)
其中d( )采用平均平方差(Mean Squared Error,MSE)計算輸入數(shù)據(jù)的損失.Lf1、Lf2和Lf3之間的不同在于f*(x)所選F-net提取的不同層次的圖像高維特征.
文中設計的網(wǎng)絡進行的是SVHN到MNIST數(shù)據(jù)集圖像域轉換,輸入的是SVHN數(shù)據(jù)集輸出其對應的MNIST數(shù)據(jù)集,SVHN和MNIST數(shù)據(jù)集均為0~9數(shù)字,前者是各種街道門牌號數(shù)字后者是手寫數(shù)字.文中將U-net拆分成F-net和G-net如圖1所示,在網(wǎng)絡訓練階段單獨提取F-net部分并在其之后添加softmax層令該層的輸出為0~9,將帶有0~9標簽的SVHN數(shù)據(jù)集輸入F-net進行有監(jiān)督的訓練,保存F-net網(wǎng)絡部分的權重.共享F-net的權重到生成網(wǎng)絡中,即F-net的初始權重為在SVHN數(shù)據(jù)集上已訓練的權重,G-net為隨即初始化權重,Pair-net亦為隨即初始化權重.在進行本文設計的GAN訓練時,采用一次訓練生成網(wǎng)絡多次訓練配對鑒別網(wǎng)絡的方式進行.
U-net的結構如圖2所示,文中將U-net劃分為F-net和G-net兩部分網(wǎng)絡.圖2中F-net的每層由卷積層、Bacth Normalization層以及ReLU激活層組成,第一層輸入的為32*32*3的SVHN數(shù)據(jù),輸出的特征圖大小為16*16數(shù)量為64張,采用的卷積核大小為3*3,F(xiàn)-net之后的每層結構均采用該設計,最后一層的卷積核大小為4*4.圖2中G-net輸入的數(shù)據(jù)張量為1*1*512,前三層包括反卷積層、Bacth Normalization層和ReLU激活層,且第一層的反卷積層的卷積核大小為4*4其余為3*3,最后一層網(wǎng)絡將ReLU激活層換成tanh激活層用于輸出生成的32*32*1數(shù)據(jù)類型的MNIST圖像.
圖2 生成網(wǎng)絡結構圖
數(shù)據(jù)在GAN網(wǎng)路訓練開始時需對于F-net進行預訓練,提取出F-net進行修改后的網(wǎng)絡如圖3所示.圖3中在原有的F-net基礎上增加了softmax層輸出為預測的0~9數(shù)值.F-net的預訓練輸入的是帶有0~9標簽的SVHN數(shù)據(jù)集且每個樣本的像素為32*32,在訓練階段使用的損失函數(shù)為交叉熵損失函數(shù),如下:
(6)
G-net主要由反卷積組成用于數(shù)據(jù)生成,學習的是訓練數(shù)據(jù)集的數(shù)據(jù)分布,根據(jù)F-net提取的圖像高維特征生成目標數(shù)據(jù).F-net的預訓練主要是為了更好的獲取原數(shù)據(jù)集的語義特征,以使得G-net能夠更好的學習數(shù)據(jù)轉化中目標數(shù)據(jù)的分布.而G-net的訓練依賴對抗網(wǎng)絡pair-net,即在對抗訓練的過程中G-net網(wǎng)絡才能有效學習數(shù)據(jù)的分布,因此,無需對G-net進行預訓練.
Pair-net的結構如圖4所示,網(wǎng)絡的輸入為INPUT1和INPUT2,其中INPUT1為32*32*3的SVHN數(shù)據(jù)集的圖像,INPUT2輸入的為32*32*1的MNIST數(shù)據(jù)集圖像或者由生成網(wǎng)絡生成的MNIST數(shù)據(jù)集圖像.Pair-net的輸出為0~1之間的數(shù)值反映兩幅圖像配對的程度,輸出值為1時表示圖像為完全真實的配對,0表示圖像完全不配對.Pair-net的第一層網(wǎng)絡為Concat層用于合并輸入的圖像數(shù)據(jù)生成32*32*4的張量,之后的三層網(wǎng)絡采用3*3大小的卷積核,且每層包含Bacth Normalization層和ReLU層,最后一層將ReLU層置換成Sigmoid層并且卷積核大小改為4*4.Pair-net訓練時將SVHN數(shù)據(jù)集及與其數(shù)字對應的MNIST數(shù)據(jù)集作為完全真實配對的正樣本,將SVHN數(shù)據(jù)集以及由生成網(wǎng)絡產(chǎn)生的與其對應數(shù)字的圖像作為負樣本,Pair-net的損失函數(shù)如式2所示.對于損失函數(shù)的優(yōu)化采用adam方法,Pair-net的初始化同F(xiàn)-net一樣采用哈維爾初始化.
圖4 Pair-net結構圖
生成網(wǎng)絡包含F(xiàn)-net和G-ne兩部分,訓練過程如圖5所示.其中,F(xiàn)-net部分權重的初始化采用在SVHN數(shù)據(jù)集已訓練的權重,G-net含有反卷積層權重采用隨機初始化方法.輸入SVHN數(shù)據(jù)集經(jīng)F-net提取高維特征,再經(jīng)過G-net生成目標域圖像,將生成的圖像再次輸入F-net提取其特征,采用用均值平方差的方法獲得兩個F-net的2,3,4特征圖之間的Lf*損失.G-net產(chǎn)生的MNIST數(shù)據(jù)輸入同其輸入的SVHN數(shù)據(jù)一起輸入Pair-net,產(chǎn)生LPair配對損失.生成的MNIST數(shù)據(jù)同真實對應的MNIST數(shù)據(jù),使用式4計算其L1損失.結合以上損失采用adam的方式進行損失函數(shù)的優(yōu)化,更新G-net網(wǎng)絡的權重.
鑒別網(wǎng)絡pairnet產(chǎn)生的損失pair_loss作用有:
① 確保輸入的數(shù)據(jù)與轉換后的數(shù)據(jù)在語義上一一對應;② 確保轉換后的數(shù)據(jù)能符合目標數(shù)據(jù)的分部.
f1_loss,f2_loo,f3_loss的作用有:
① 在網(wǎng)絡訓練的初期,彌補鑒別網(wǎng)絡不能有效返回gnet學習損失的問題;② 對損失函數(shù)起到正則化的作用,防止過擬合.
圖5 生成網(wǎng)絡訓練圖
試驗環(huán)境的GPU為GTX1080 8G顯存,CPU為Xeon E5-2698 v3 16核主頻2.3 GHz,CUDA 8.0,cudnn v5,采用tensorflow深度學習框架,原始域數(shù)據(jù)集為SVHN,目標域數(shù)據(jù)集為MNIST,圖像轉換的結果如圖6所示.其中,轉換的圖像為截取自SVHN門牌圖像中的數(shù)字,并按其對應數(shù)字轉成MNIST數(shù)據(jù)類型圖像.
由于,SVHN數(shù)據(jù)集轉換MNIST數(shù)據(jù)集為一一對應,為了驗證是否正確轉換對轉換后的MNIST數(shù)據(jù)集進行識別.文中使用文獻[12]中的圖像識別網(wǎng)絡對MNIST數(shù)據(jù)集進行分類訓練,在MNIST測試集上的誤差率為0.23%,用該網(wǎng)絡識別轉換后的MNIST正確性.令配對鑒別網(wǎng)絡α=2,即訓練一次G-net,訓練2次Pair-net,網(wǎng)絡迭代10萬次,設置文中方案不同權重的試驗對比結果見表1.其中方案一選擇的G-net損失函數(shù)為Pair-net反饋回來的配對損失和F-net第2層提取特征圖像的均值平方誤差,SVHN轉換MNIST數(shù)據(jù)后的一一對應的準確率為78.32%.方案二則增加了F-net第3層的特征圖形的均值平方誤差損失,準確率達到81.12%,方案三增加了F-net第4層特征圖像的均值平方誤差損失,準確率達到84.56%.方案四在方案三的基礎上增加了正則項L1損失最終的準確率為86.73%.
輸入相同的SVHN數(shù)據(jù)轉換成MNIST數(shù)據(jù)集,對比可見方案四的轉換結果以及轉換準確率要優(yōu)于其他三種方案.
圖6 試驗結果
表1 試驗對比結果
文中在相同的數(shù)據(jù)集以及使用相同的圖像識別網(wǎng)絡的條件下,文與其他圖像域轉換算法做了實驗對比,結果見表2.選擇3種圖像域轉換方案進行對比,其中第一種方法和和第二種方法并非單純使用GAN算法,轉換效果明顯低于文中方案,第三種使用了GAN但是通過結果可見文中方案亦優(yōu)于該算法.
表2 試驗對比圖
文中基于GAN生成式對抗網(wǎng)絡相互博弈的原則,通過使用Pair-net配對鑒別網(wǎng)絡,F(xiàn)-net產(chǎn)生特征圖像的均值平方誤差損失以及添加的L1正則項,提升了網(wǎng)絡在圖像域轉換的能力.使用SVHN數(shù)據(jù)集作為原始域數(shù)據(jù),MNIST數(shù)據(jù)集作為目標域數(shù)據(jù)進行圖像域的轉換.為了評價轉換效果,利用MNIST數(shù)據(jù)集訓練的識別網(wǎng)絡對轉換后的網(wǎng)絡進行了識別,文中算法的轉換準確率為86.73%,均高于其他三種算法.傳統(tǒng)的深度學習網(wǎng)絡,其網(wǎng)絡權重是訓練學習獲得的,網(wǎng)絡的損失函數(shù)是固定的、人工設計的,需要針對特定目標任務制定損失函數(shù).而生成式對抗網(wǎng)絡利用對抗機制,使得的網(wǎng)絡的損失函數(shù)定義不再固定,而是在網(wǎng)絡訓練的過程中,以黑盒的形式用鑒別網(wǎng)絡不斷的擬合符合當前任務的損失函數(shù).這種機制雖然實現(xiàn)了網(wǎng)絡權重的學習,損失函數(shù)的學習,但是這也使得模型有效訓練成為難點,因此,目前需額外增加固定的損失函數(shù),確保模型訓練的穩(wěn)定性,在后續(xù)工作中完善網(wǎng)絡轉換能力,提高轉換效果.