王宥翔
(鄭州中糧科研設(shè)計(jì)院 電氣所,河南 鄭州 450000)
超分辨率(Super Resolution)通過(guò)硬件或軟件提高原有圖像的分辨率。圖像超分辨率研究大體分為3類:基于插值、基于重建、基于學(xué)習(xí);在技術(shù)層面則分為超分辨率復(fù)原和超分辨率重建。超分辨率重建是通過(guò)一系列低分辨率的圖像生成一幅高分辨率的圖像過(guò)程。
超分辨率重建是用時(shí)間帶寬換取空間分辨率,實(shí)現(xiàn)時(shí)間分辨率轉(zhuǎn)換為空間分辨率。超分辨率重建各種算法的區(qū)別主要在于網(wǎng)絡(luò)構(gòu)建的思路不同,而相同思路建構(gòu)的網(wǎng)絡(luò)也存在細(xì)微的差別。超分辨率重建大部分使用單純的卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks,CNN)完成任務(wù),但是CNN網(wǎng)絡(luò)在池化層和平移不變性方面容易出現(xiàn)問(wèn)題,文獻(xiàn)[1]揭示并分析了卷積神經(jīng)網(wǎng)絡(luò)在變換兩種空間表征(笛卡爾空間坐標(biāo)和像素空間坐標(biāo))時(shí)的常見(jiàn)缺陷。本文基于深度學(xué)習(xí)的方案,選擇更為優(yōu)秀的生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Networks,GAN)進(jìn)行超分辨率重建。
生成模型泛指在給定一些隱含參數(shù)的條件下隨機(jī)生成觀測(cè)數(shù)據(jù)的模型,主要分為兩類:一是建立有確切數(shù)據(jù)的分布函數(shù)模型;二是在無(wú)需完全明確數(shù)據(jù)分布函數(shù)模型的條件下直接生成一個(gè)新樣本[2],如GAN(圖1)。GAN通過(guò)對(duì)抗的方式,同時(shí)訓(xùn)練生成器(generator)和判別器(discriminator),生成器用于生成假樣本,讓這個(gè)假樣本無(wú)限逼近真實(shí)樣本,判別器則需要盡量準(zhǔn)確地判斷輸入的是真實(shí)樣本還是由生成器自己生成的假樣本。
圖1 GAN結(jié)構(gòu)
GAN的主要結(jié)構(gòu)由一個(gè)生成模型G(generator)和一個(gè)判別模型D(discriminator)組成。輸入圖片之后,程序提取輸入的圖片,并采樣轉(zhuǎn)化成數(shù)據(jù)tensor,數(shù)據(jù)輸入到網(wǎng)絡(luò)中開始計(jì)算,然后生成器G和判別器D開始它們的零和最大最小博弈。簡(jiǎn)單來(lái)說(shuō),通過(guò)生成器,低分辨率的圖像可以重建一張高分辨率的圖像,然后由判別器網(wǎng)絡(luò)判斷。當(dāng)生成器網(wǎng)絡(luò)的生成圖能夠很好地“騙”過(guò)判別器網(wǎng)絡(luò),使判別器認(rèn)為這個(gè)生成圖是原數(shù)據(jù)集中的圖像,這里超分辨率重構(gòu)的網(wǎng)絡(luò)的目標(biāo)就達(dá)成了。生成器與判別器的工作原理如圖2所示,數(shù)據(jù)傳遞如圖3所示。
圖2 生成器與判別器的工作原理
圖3 生成器與判別器的數(shù)據(jù)傳遞
總體來(lái)說(shuō),在GAN中二者互相博弈,生成器不斷生成并輸出假的數(shù)據(jù),并與訓(xùn)練集一同輸入判別器中進(jìn)行判斷,繼續(xù)優(yōu)化學(xué)習(xí)。在這個(gè)過(guò)程中,生成器和判別器反復(fù)博弈,共同進(jìn)化,最終達(dá)到超進(jìn)化,經(jīng)過(guò)有限次迭代之后輸出數(shù)據(jù)并轉(zhuǎn)化為新的圖像輸出[3]。圖4是SRGAN的網(wǎng)絡(luò)結(jié)構(gòu),比較直觀的描述了GAN在解決圖像超分辨率的網(wǎng)絡(luò)運(yùn)行思路。
圖4 SRGAN的網(wǎng)絡(luò)結(jié)構(gòu)
GAN模型本質(zhì)上是一個(gè)最大最小博弈。目標(biāo)函數(shù)為
minGmaxDV(G,D)=Ex~pr(x)[logD(x)]+Ez~pr(z)[log(1-D(G(z)))],
(1)
其中,E代表期望,x~pr(x)代表x服從pr(x)分布,z是隨機(jī)噪聲,服從z~pr(z)的分布。而如何得出這個(gè)結(jié)論,就要關(guān)系到生成器和判別器的網(wǎng)絡(luò)原理。
2.2.1 判別器
判別器是程序需要優(yōu)先訓(xùn)練的模型,使它能夠判別一個(gè)輸入數(shù)據(jù)是否真的來(lái)自真實(shí)數(shù)據(jù)集,如果返回值大于0.5就為真,小于0.5則為假。可以看出,使用最簡(jiǎn)單的二分類就可實(shí)現(xiàn),這里使用交叉熵的方法[4]。
給定一個(gè)樣本(x,y),y∈{1,0},表示其來(lái)自生成器還是真實(shí)數(shù)據(jù)。對(duì)于輸入的x,判別器會(huì)返回一個(gè)y,y表示x屬于真實(shí)數(shù)據(jù)的概率,
P(y=1|x)=D(x),
(2)
反之,x屬于生成的圖像數(shù)據(jù)概率
P(y=0|x)=1-D(x)。
(3)
判別器的目的是最小化交叉熵,交叉熵的表達(dá)式是[5]
minD(-Ex~p(x)(ylogP(y=1|x))+(1-y)logP(y=0|x)),
(4)
帶入式(2)和式(3),得到
minD(-Ex~p(x)(yD(x)+(1-y)(1-D(x))))。
(5)
假設(shè)整個(gè)樣本數(shù)據(jù)里面真實(shí)圖像數(shù)據(jù)和生成器生成的圖像數(shù)據(jù)是等比例的,
(6)
得到
(7)
然后最小化最大化互換,同時(shí)把負(fù)號(hào)變?yōu)檎?hào),
maxDEx~pr(x)(D(x))+Ex~pg(x)(1-D(x))。
(8)
如果x~pg(x),代表x是生成器生成的,而生成器又是滿足z~p(z)分布而生成的,再次替換可得
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z))),
(9)
即所需求的目標(biāo)函數(shù)。
2.2.2 生成器
生成器是判別器訓(xùn)練完成后才開始訓(xùn)練的模型,作用是在給定輸入的情況下得到一定的輸出,然后繼續(xù)送給判別器判斷,之后返回給自身一個(gè)誤差值,從而繼續(xù)學(xué)習(xí)。
生成器的目標(biāo)剛好和判別器相反,即讓判別器把自己生成的樣本判別為真實(shí)樣本。因?yàn)镚AN網(wǎng)絡(luò)的本質(zhì)數(shù)學(xué)模型是一個(gè)最大最小博弈,通過(guò)判別器得到了目標(biāo)函數(shù),從而得到最大值max,所以生成器的目的就是得到最小值min[6]。目標(biāo)函數(shù)
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z)))
(10)
由兩部分構(gòu)成,由后一部分可得生成器目標(biāo)
minGEz~p(z)(1-D(G(z)))。
(11)
將生成器與判別器的函數(shù)結(jié)合,即得到生成對(duì)抗網(wǎng)絡(luò)的模型,
minGmaxDV(G,D)=Ex~pr(x)(logD(x))+Ez~p(z)(log(1-D(G(z))))。
(12)
訓(xùn)練時(shí)的優(yōu)化需要引入生成對(duì)抗網(wǎng)絡(luò)的損失函數(shù),
LossG=log(1-D(G(z)))or-log(D(G(z))),LossD=-log(D(x))or-log(1-D(G(z))),
(13)
LossG=log(1-D(G(z)))or-log(D(G(z)))。
(14)
由生成器的目標(biāo)式得
minGEz~p(z)(1-D(G(z)))。
(15)
后面一部分是原作者Ian Goodfellow提出的,效果等同于優(yōu)化前面那個(gè)而且梯度性質(zhì)更好。
LossD=-log(D(x))-log(1-D(G(z))),
(16)
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z)))。
(17)
2.4.1 判別器越好,生成器梯度消失越嚴(yán)重
在最優(yōu)判別器的條件下,最小化生成器的損失函數(shù)和最小化P1與P2之間的JS散度是等價(jià)的[7],
(18)
對(duì)于P1與P2來(lái)說(shuō)是完全對(duì)稱的,JS是兩個(gè)KL散度的疊加(KL散度又稱相對(duì)熵),一定是大于等于0的,所以JS散度一定大于等于0。在這里可能會(huì)出現(xiàn)嚴(yán)重的問(wèn)題:如果兩個(gè)分布沒(méi)有重疊的話,JS散度就為0,而在訓(xùn)練初期,兩個(gè)分布必然是基本不會(huì)重疊,所以假如在這里判別器被訓(xùn)練得過(guò)于好,損失函數(shù)就會(huì)經(jīng)常收斂到固定的-2 log 2,從而產(chǎn)生沒(méi)有梯度的情況。然后網(wǎng)絡(luò)就沒(méi)法繼續(xù)訓(xùn)練下去了,對(duì)抗網(wǎng)絡(luò)中的生成器和判別器是要一起進(jìn)化變強(qiáng)的,一個(gè)過(guò)于強(qiáng)將會(huì)導(dǎo)致另一個(gè)無(wú)法繼續(xù)訓(xùn)練[8]。
2.4.2 可能出現(xiàn)梯度不穩(wěn)定和模式崩潰
GAN采用的是對(duì)抗訓(xùn)練的方式,判別器的梯度更新來(lái)自判別器,生成一個(gè)樣本,交給判別器去評(píng)判,判別器會(huì)輸出生成的假樣本是真樣本的概率。生成器會(huì)根據(jù)這個(gè)反饋不斷改善。但假如有一次生成器生成的并不真實(shí),判別器卻出了問(wèn)題,給了正確評(píng)價(jià),或者在一次生成器生成的結(jié)果中存在某一些特征被判別器所認(rèn)可了,這時(shí)候生成器就會(huì)認(rèn)為這里的輸出反而是正確的,接下來(lái)繼續(xù)輸出相同的數(shù)據(jù)判別器就還會(huì)給出高的評(píng)分,最終就會(huì)導(dǎo)致生成結(jié)果中的一些重要信息或特征殘缺[9]。
首先需要生成器(G)生成圖片模型,判別器(D)判斷圖片是否為真,如圖5所示。
圖5 GAN網(wǎng)絡(luò)架構(gòu)
首先需要向生成器輸入一個(gè)噪聲,生成隨機(jī)數(shù)組,繼續(xù)輸出一個(gè)數(shù)據(jù)轉(zhuǎn)換為一張圖片,輸入圖片之后,經(jīng)過(guò)判別器來(lái)輸出是一個(gè)數(shù)1或者0,代表圖片是否是狗。
然后通過(guò)訓(xùn)練網(wǎng)絡(luò),把真圖與假圖拼接,打上不同的標(biāo)簽,真圖為1,假圖為0,送到網(wǎng)絡(luò)中訓(xùn)練。
3.2.1 數(shù)據(jù)輸入
聲明集合dataloader,將訓(xùn)練和測(cè)試數(shù)據(jù)都放入其中。
3.2.2 訓(xùn)練網(wǎng)絡(luò)
先重寫構(gòu)造函數(shù),構(gòu)造一個(gè)父類的函數(shù) “super”,然后定義網(wǎng)絡(luò)結(jié)構(gòu)block,運(yùn)用nn.sequential將多個(gè)函數(shù),如卷積函數(shù)Conv2d和激活函數(shù)PReLU,并列放置,經(jīng)過(guò)多個(gè)ResidualBlock殘差網(wǎng)絡(luò)模塊處理。采樣之后,進(jìn)入前向傳播forward函數(shù),最后經(jīng)過(guò)tanh函數(shù)映射到-1到1,最后得到一個(gè)0到1的數(shù)據(jù)輸出[10]。
判別器是一個(gè)二分類的模型,先重寫構(gòu)造函數(shù)構(gòu)造父類函數(shù),然后進(jìn)入多層的網(wǎng)絡(luò),在進(jìn)入一層池化層之后,取平均值下采樣,得到1×1的數(shù)據(jù),最后只得到batchsize的數(shù)據(jù),然后通過(guò)sigmoid函數(shù)將實(shí)數(shù)域映射到0~1,即batchsize的概率,符合判別器二分類概率的原理[11]。
通過(guò)優(yōu)化器進(jìn)行判別器的訓(xùn)練。首先為了優(yōu)化判別器,將其梯度歸零,然后規(guī)定判斷真實(shí)圖片和虛假圖片的概率,接著規(guī)定判別器的損失函數(shù),計(jì)算出d_loss,然后執(zhí)行上面的步驟。
訓(xùn)練生成器時(shí),將生成器的梯度置零后,生成一個(gè)假的圖片,輸入判別器,得出判別器判斷為假的概率,輸入給生成器的損失函數(shù),計(jì)算得出g_loss,再反向傳播backward,最終運(yùn)行開始訓(xùn)練。
完整的網(wǎng)絡(luò)架構(gòu)中日志記錄以及數(shù)據(jù)輸入輸出可視化不再贅述,可將生成模型記錄保存在字典文件pth之中,以供之后的測(cè)試或者訓(xùn)練使用。
完成了GAN構(gòu)造并經(jīng)過(guò)訓(xùn)練之后,進(jìn)行網(wǎng)絡(luò)性能測(cè)試。筆者下載了超分辨率重構(gòu)的數(shù)據(jù)集,包含×4和×8的每個(gè)大約3 000張圖片的測(cè)試用數(shù)據(jù)集,數(shù)據(jù)集文件列表如圖6所示。
圖6 超分辨率重構(gòu)數(shù)據(jù)集
因?yàn)樯窠?jīng)網(wǎng)絡(luò)訓(xùn)練運(yùn)算量巨大,且需要占用大量?jī)?nèi)存,所以這里將其放到訓(xùn)練試驗(yàn)機(jī)上,運(yùn)用4塊RTX 3090顯卡進(jìn)行訓(xùn)練。訓(xùn)練整體大概1 000個(gè)迭代epoch,最終得到兩個(gè)記錄模型權(quán)重的pth文件,這兩個(gè)權(quán)重文件可以直接輸入測(cè)試網(wǎng)絡(luò),以下通過(guò)幾個(gè)測(cè)試圖片檢測(cè)訓(xùn)練的結(jié)果。
測(cè)試所用的一組原圖Ground truth,如圖7所示?!?超分的測(cè)試結(jié)果如圖8所示?!?超分的測(cè)試結(jié)果如圖9所示。可以看出,在×8的超分上,如果細(xì)節(jié)比較小的話,得出的超分圖會(huì)比較邊緣性的模糊,×4的超分結(jié)果已經(jīng)比較理想。
圖7 原圖
圖8 ×4測(cè)試結(jié)果
圖9 ×8測(cè)試結(jié)果
整體來(lái)說(shuō),網(wǎng)絡(luò)訓(xùn)練結(jié)果比較理想,成功收斂且沒(méi)有出現(xiàn)梯度消失以及模式崩潰的情況。說(shuō)明利用深度學(xué)習(xí)的神經(jīng)網(wǎng)絡(luò)中的GAN生成對(duì)抗網(wǎng)絡(luò),能夠?qū)崿F(xiàn)圖像超分辨率的目標(biāo)。