王憲保,肖本督,姚明海
(浙江工業(yè)大學 信息工程學院,杭州 310023)
機器學習是從數(shù)據(jù)中學習并挖掘未知規(guī)律,然后利用規(guī)律對樣本進行分析與預測.因此數(shù)據(jù)的收集尤為重要,是機器學習中最為重要的步驟之一.在如今的實際應用場景中,最為常見的數(shù)據(jù)集由少量標簽數(shù)據(jù)與大量無標簽數(shù)據(jù)組成.監(jiān)督學習的標記數(shù)據(jù)太少不利于模型訓練,而無監(jiān)督學習則僅利用了無標記數(shù)據(jù)而浪費了標記數(shù)據(jù)上提供的信息.半監(jiān)督學習(SSL,Semi-supervised Learning)是介于監(jiān)督學習和無監(jiān)督學習之間的機器學習方法,它利用數(shù)據(jù)集中的大量無標記樣本提供相應信息來幫助標記數(shù)據(jù)協(xié)同訓練模型,是近年來的研究熱點之一.研究者們在SSL上進行了許多深入的研究并提出了許多算法.
SSL算法分為半監(jiān)督分類、半監(jiān)督回歸、半監(jiān)督聚類等等.本文提出的算法是基于圖的半監(jiān)督分類,其核心理念是利用無標記圖像來協(xié)助訓練標記圖像,然后獲得比只有少量標記圖像訓練得到的分類器性能更優(yōu)的分類器.
監(jiān)督學習模型的訓練必須使用人工給整個數(shù)據(jù)集標記,并且這些算法對數(shù)據(jù)集的數(shù)量有著嚴格的要求.經(jīng)過長時間的發(fā)展監(jiān)督學習已經(jīng)擁有許多網(wǎng)絡結(jié)構(gòu)合理的成熟算法,其中卷積神經(jīng)網(wǎng)絡(CNN,Convolutional Neural Network)最為典型.因此本文將CNN作為算法骨干網(wǎng)絡的網(wǎng)絡結(jié)構(gòu).
花費可以接受的人工開銷可以獲得遠多于標記數(shù)據(jù)的無標記圖像數(shù)據(jù)集,然而對模型網(wǎng)絡來說,這些無標記數(shù)據(jù)提供的信息是無序的、巨大的、無法確定的,如何處理這些無標記圖像成為一個難點.保有一定含量的背景信息可以幫助模型在更換場景時保證模型識別準確率,起到增強模型魯棒性的作用.但過多的背景信息會減弱模型對目標特征的提取,過少的背景突出目標又將導致模型過擬合,因此目標和背景信息的比例需要控制在一個范圍內(nèi).本文提出目標初定位網(wǎng)絡可以通過設定閾值控制目標和背景信息的比例從而增強模型性能.
Merz等人在1992年的文獻[1]中提出了SSL的概念.2003年,Zhou等人在文獻[2]中提出了一種基于平滑(Smooth)理念的SSL算法,其核心思想是設計的算法對于數(shù)據(jù)集里標記數(shù)據(jù)與無標記數(shù)據(jù)共同揭示的內(nèi)在結(jié)構(gòu)應該足夠平滑.這篇文章對于半監(jiān)督學習的貢獻是卓越的,后來的研究者們基于這種思想將專注點從模型本身轉(zhuǎn)移到數(shù)據(jù)集的內(nèi)部結(jié)構(gòu).
Fang等人[3]提出了一種非負低秩表示(NNLRR,Non-Negative Low-Rank Representation)的魯棒半監(jiān)督子空間圖聚類算法.這個方法在相似矩陣(Affinity Matrix)構(gòu)建過程中明確引入了標記數(shù)據(jù)的標記信息幫助其構(gòu)建.之后Liu等人[4]將非負低秩表示、流形平滑度與標簽適應度整合到一個框架中,提出了名為基于非負低秩表示的流形嵌入半監(jiān)督學習的算法.這些研究都在很大程度上降低了以往的算法中標記數(shù)據(jù)上標記傳遞信息的失真度,但是設計合理的數(shù)據(jù)集內(nèi)部結(jié)構(gòu)的做法對于半監(jiān)督學習算法的提升是有限度的.
研究者們在其他方面進行了深入的研究,Tarvainen與Valpola在文獻[5]中提出了均值教師模型(MTM,Mean Teacher Model),其主要做法是引入指數(shù)移動平均(EMA,Exponential Moving Average)的概念,教師模型的網(wǎng)絡結(jié)構(gòu)與學生模型的相同.不同的是教師模型為學生模型過去所有訓練效果的平均值,即時間上的平均,因此訓練過程中偶爾出現(xiàn)的波動對訓練效果的負面影響被大大減小,提升了模型訓練效果.對學生模型進行正常訓練,然后取學生模型權(quán)重的指數(shù)移動平均作為教師模型的權(quán)重.圖像分別輸入學生模型和教師模型獲得模型預測的類概率向量,取這兩個類概率模型的一致性差值作為模型訓練的評判標準.類似的做法還有文獻[6]中Verma等人提出的一致性插值訓練(ICT,Interpolation Consistency Training)的半監(jiān)督學習算法,它鼓勵無標記點的插值處的預測與這些點處的預測值一致,是通過模型輸出的結(jié)果調(diào)整模型訓練的過程.
同時半監(jiān)督學習的理念可以和許多領域相結(jié)合解決實際問題.例如,楊緒兵等人在文獻[7]中將半監(jiān)督學習和矩陣模式相結(jié)合.王格格等人在文獻[8]中將生成對抗網(wǎng)絡擴展到半監(jiān)督學習.而付曉等人[9]設計了一種編碼器用以提取圖像特征并結(jié)合半監(jiān)督特殊數(shù)據(jù)集構(gòu)成來突出模型的分類能力.
CNN是監(jiān)督學習中的經(jīng)典網(wǎng)絡.其歷史上進行突破的階段是AlexNet[10]的出現(xiàn),之后出現(xiàn)了VGGNet[11]、ResNet[12]、R-CNN[13]等一系列延伸網(wǎng)絡及其變體.
然而這些網(wǎng)絡僅讓人們知道這種網(wǎng)絡結(jié)構(gòu)對目標的識別更加準確,但至于目標什么特征的權(quán)重比重較大就不得而知.類激活映射(CAM,Class Activation Mapping)的出現(xiàn)使得人們可以更加直觀的理解網(wǎng)絡,人們可以借助CAM來調(diào)整網(wǎng)絡進而提高網(wǎng)絡的識別率.
根據(jù)文獻[14],生成CAM的做法是在最終輸出層之前,加入全局平均池化層,獲得特征圖對應的特征權(quán)重向量W,經(jīng)過W加權(quán)的特征圖向量疊加輸出類激活映射.其流程如圖1所示.
圖1 類激活映射生成示意圖
近年來研究者們受到這種可視化理解網(wǎng)絡模型的做法的啟發(fā)進行深入的研究,使得CAM在許多領域上得到應用.Chattopadhay等人在文獻[15]中對CAM進行改進使得提出的算法可以在單張圖片上解釋并分割多個目標實例.Li等人[16]設計了一種可解釋的輕量級分類器SCOUTER.SCOUTER的可解釋性的實現(xiàn)是根據(jù)CAM的原理實現(xiàn)的,文章中具體做法是分類器會在網(wǎng)絡高度關注的特征區(qū)域給出其對于所有類別肯定或者是否定的評分.最終將這些評價疊加得到圖像屬于某一類別的置信度.
表1為文中所使用的參數(shù)變量對應表.
表1 參數(shù)對應表
在本節(jié)中主要介紹結(jié)合CAM的目標初定位掩膜生成網(wǎng)絡的機制,達到調(diào)整目標與背景信息之間比例關系的目的.
CNN網(wǎng)絡的主要結(jié)構(gòu)由輸入層、卷積層、池化層與輸出層構(gòu)成.根據(jù)2.2節(jié)介紹的CAM原理,將CNN網(wǎng)絡轉(zhuǎn)化為具有CAM功能的網(wǎng)絡.
本文選擇全局平均池化(GAP,Global Average Pooling)而不是全局最大池化(GMP,Global Max Pooling),原因是算法要求目標初定位網(wǎng)絡得到最大可能的區(qū)分目標類別的特征區(qū)域.而GMP只能輸出目標最高辨識度的區(qū)域而完全拋棄了低辨識度的特征區(qū)域.
下面以VGGNet-16為例搭建結(jié)合CAM的目標初定位網(wǎng)絡.首先將圖片尺寸調(diào)整為寬高都為224像素輸入VGGNet-16中,圖片在網(wǎng)絡中流轉(zhuǎn)到卷積層,該層輸出的尺寸為[7,7,512],這個輸出也被稱為特征圖向量.使fk(w,h)表示圖像任意位置(w,h)在特征向量圖中內(nèi)核單元k的激活響應,其中k表示特征圖向量中第k個[7,7]的特征圖.將fk(w,h)輸入全局平均池化層,得到其輸出:
(1)
其中(w0,h0)表示圖像左上角坐標,(w0+wI,h0+hI)表示圖像右下角坐標.wI為圖像的寬度,hI為圖像的高度.
對于帶有類別c標簽的圖像,CAM可用公式(2)計算得到:
(2)
將公式(1)代入公式(2)中,得到公式(3):
(3)
圖像預測為c類時,圖像中任意坐標位置的CAM值由公式(4)計算得出:
(4)
結(jié)合公式(3)和公式(4)可以看出,CAM是在圖像中所有像素位置上計算Pc的值,即CNN網(wǎng)絡對目標類別判定的依據(jù).
將Pc投影到[0,255]數(shù)值范圍的RGB空間中,得到Ic.通過公式(5)將Ic與原圖像Iori疊加得到最終的CAM的熱度圖Ih(如圖2(a)所示).
圖2 結(jié)合CAM的目標初定位生成示意圖
(5)
根據(jù)公式(6)計算得到目標初定位網(wǎng)絡的一個輸出Imask,即初定位掩膜(見圖2(b)).
(6)
將Sc作為公式(7)的輸入,得到目標初定位掩膜的質(zhì)量評分Smask.
(7)
根據(jù)公式(8)所示將原圖像、目標初定位掩膜和掩膜評分生成如圖2(c)所示目標初定位圖Iout.
Iout(p)=Iori(p)Imask(p)Smask,p=(w,h)
(8)
在本節(jié)中介紹了強化MTM的搭建.MTM是將學生模型權(quán)重θ輸入EMA計算得到教師模型權(quán)重θ′,如公式(9)所示:
(9)
在研究中對比發(fā)現(xiàn)α的數(shù)值為0.99或0.999時,模型的訓練效果會好于其他數(shù)值.
因此更新教師模型權(quán)重時,比重較大的是上一時刻的教師模型權(quán)重.當前時刻的學生模型的權(quán)重比重較小,但其影響的值將加入教師模型的權(quán)重中.換而言之,教師模型的權(quán)重是過去所有時刻的學生模型權(quán)重與相對應的權(quán)重的累加.并且隨著時間的推移,時間上距離現(xiàn)在越遠的學生模型權(quán)重將會被不斷懲罰,減小自身在教師模型權(quán)重中的占比.
選取均方誤差(MSE,Mean-square Error)作為模型的評估函數(shù),得到目標優(yōu)化函數(shù),也稱為一致性差值函數(shù)J(X,θ,θ′):
J(X,θ,θ′)=‖T(X,θ′)-S(X,θ)‖2
(10)
本文提出一種結(jié)合類激活映射的目標初定位掩膜的增強MTM算法(MC-MTM),算法的結(jié)構(gòu)示意圖如圖3所示.
圖3 MC-MTM算法結(jié)構(gòu)
MC-MTM分為目標初定位網(wǎng)絡和MTM模型訓練網(wǎng)絡.首先原始數(shù)據(jù)流入目標初定位網(wǎng)絡,它的骨干網(wǎng)絡是在ImageNet數(shù)據(jù)集上預先訓練過的CNN網(wǎng)絡.數(shù)據(jù)集由標記數(shù)據(jù)集與無標記數(shù)據(jù)集構(gòu)成,即X=xL+xU.xU輸入訓練后的模型,得到兩個輸出:Smask和Imask,然后Imask、Iori與對應的Smask生成Iout.
將處理后的xU記為xPL.將xPL與xL混合生成新的數(shù)據(jù)集,記為Xn.如圖3所示,網(wǎng)絡在處理標記數(shù)據(jù)與無標記數(shù)據(jù)的方式是不同的.當數(shù)據(jù)為標記數(shù)據(jù)時,數(shù)據(jù)輸入學生模型進行正常的迭代訓練,得到學生模型預測的類概率向量PS,并且直接輸出標記類的概率值為1,其他類概率值都為0的類概率向量P.然后通過交叉熵損失函數(shù)驗證模型預測的PS與真值P之間的相似性;而數(shù)據(jù)為無標記數(shù)據(jù)時,得到PS的步驟不變,但需將圖片輸入教師模型獲得其預測的類概率向量PT.計算PT與PS之間的一致性差值損失,即通過一致性差值函數(shù)J(X,θ,θ′)來判斷模型此次的訓練效果.最終將交叉熵損失函數(shù)與一致性差值函數(shù)相疊加生成算法最終的優(yōu)化函數(shù).
實驗硬件平臺為NVIDIA RTX 2060 SUPER 8G,軟件平臺為CUDA 10.2、PyTorch 1.7.0和Python 3.7.所有實驗結(jié)果均在上述實驗環(huán)境中獲得.實驗選取了MTM作為對比的基準模型,在標記數(shù)據(jù)占比、骨干網(wǎng)絡模型、數(shù)據(jù)集等維度進行對比.多分類的評判標準是Top1、Top5,二分類的則是Top1.所有實驗的實驗結(jié)果均為多次實驗的平均結(jié)果.
在Top1的評判標準下,要求網(wǎng)絡判別的結(jié)果必須是真實結(jié)果.而在Top5的評判標準下,在網(wǎng)絡識別結(jié)果中取可能性最高的5類,只需這5類中包含正確類別,則判別為網(wǎng)絡正確識別圖像.
本文選取兩個數(shù)據(jù)集進行一系列的對比實驗.下面的篇幅對兩個數(shù)據(jù)集分別做詳細介紹.
4.1.1 貓狗大戰(zhàn)(Dogs vs.Cats)數(shù)據(jù)集(1)https://www.kaggle.com/c/dogs-vs-cats
貓狗大戰(zhàn)是Kaggle大數(shù)據(jù)競賽2013年的賽題,是一個經(jīng)典的二分類問題.貓狗大戰(zhàn)數(shù)據(jù)集由訓練集與測試集構(gòu)成,訓練集全部是標記數(shù)據(jù),其中含有貓的圖像12.5K、狗的圖像12.5K.測試集是貓狗混合無標記圖像12.5K,由于訓練集中全部是標記數(shù)據(jù),需要自行制作訓練數(shù)據(jù)集以滿足SSL對數(shù)據(jù)集的要求.具體做法為屏蔽訓練集中某一部分的標簽作為訓練時的無標記數(shù)據(jù)集,剩下的圖片為標記數(shù)據(jù)集.
4.1.2 動物(Animals-10)數(shù)據(jù)集(2)https://www.kaggle.com/alessiocorrado99/animals10
動物數(shù)據(jù)集是Kaggle網(wǎng)站上一位網(wǎng)友分享的數(shù)據(jù)集,包含大約10種類別的28K的動物圖像,分別是:狗,貓,馬,蜘蛛,蝴蝶,雞,綿羊,牛,松鼠,大象.每個類別的圖像計數(shù)范圍從2K~5K不等.與貓狗大戰(zhàn)數(shù)據(jù)集不同,數(shù)據(jù)集的主目錄按類別分為多個文件夾,圖像文件夾的名字記為圖像類別,制作訓練數(shù)據(jù)集時從每個文件夾中移出相同比例圖像記為訓練集中的無標記圖像,文件夾中剩余的圖像則是訓練集中的標記圖像.
在本次實驗中,MC-MTM的骨干網(wǎng)絡采用ResNet-50,實驗的數(shù)據(jù)集則采用Dogs vs.Cats數(shù)據(jù)集.設置3組標記數(shù)據(jù)占比實驗,標記數(shù)據(jù)占比是指標記數(shù)據(jù)在訓練數(shù)據(jù)中的占比.將標記數(shù)據(jù)占比分別設置為4%,8%,16%,其他訓練超參數(shù)設置為表2中羅列的數(shù)值.
表2 超參數(shù)表
得到如圖4所示的實驗結(jié)果.圖4表示Top1準確率隨著MC-MTM訓練迭代代數(shù)變化的結(jié)果曲線.圖中實線表示的是8%標記數(shù)據(jù)占比的實驗結(jié)果,虛線表示的是16%標記數(shù)據(jù)占比的實驗結(jié)果,點劃線則表示4%標記數(shù)據(jù)占比的實驗結(jié)果.
圖4中隨著迭代次數(shù)的增加,點劃線呈現(xiàn)快速上升而后緩慢下降最后收斂的趨勢.而實線和虛線總體走勢和點劃線完全不同,呈現(xiàn)先快速升高而后緩慢上升最后收斂的趨勢.最終得到表3中的實驗結(jié)果數(shù)據(jù).
圖4 不同標記數(shù)據(jù)占比的數(shù)據(jù)集在MC-MTM的實驗結(jié)果
表3 標簽數(shù)據(jù)占比實驗數(shù)據(jù)
分析實驗結(jié)果得到在MC-MTM中對標記數(shù)據(jù)具有一定數(shù)量要求的結(jié)論.標記數(shù)據(jù)量過少,數(shù)據(jù)集中的無標記數(shù)據(jù)不僅不能夠幫助標記數(shù)據(jù)訓練網(wǎng)絡反而會使得模型的性能惡化.4%標記數(shù)據(jù)占比實驗證明了這個論點,模型Top1準確率從峰值的61.82%不斷降低到49.55%收斂,在140代~180代之間準確率甚至出現(xiàn)急劇下降.
而在8%標記占比和16%標記占比的實驗結(jié)果中,無標記數(shù)據(jù)提供的信息幫助標記數(shù)據(jù)提升了模型的性能.16%標記占比的實驗相較于8%標記占比的模型的性能提升了1.44%.因此標記數(shù)據(jù)的數(shù)量增長可以一定程度提升MC-MTM性能.
在實際應用中需要權(quán)衡標記數(shù)據(jù)占比的選擇.面對大型數(shù)據(jù)集時,標記數(shù)據(jù)占比增長1%所耗費的人工與時間成本十分巨大.且必須防止標記數(shù)據(jù)不足引發(fā)的模型訓練效果不佳.16%的標記數(shù)據(jù)占比相較于8%的需要多付出一倍的時間或人工成本,但模型性能僅有小幅提升.因此本文的實驗選擇表3中數(shù)據(jù)為黑體的8%標簽數(shù)據(jù)占比作為后續(xù)實驗的超參數(shù).
本次實驗設置不同模型在兩個數(shù)據(jù)集上的比較,模型的選擇有本文所提出的MC-MTM和基準模型MTM.按照4.2節(jié)將標記數(shù)據(jù)占比設為8%,其他的訓練超參數(shù)設置為表2中所示的數(shù)值.
最終得到如圖5所示的實驗結(jié)果.圖5(a)和圖5(b)分別展示了兩個網(wǎng)絡在貓狗大戰(zhàn)數(shù)據(jù)集和動物數(shù)據(jù)集上的最終結(jié)果.
圖5(a)中有兩種類型的曲線:表示MC-MTM的實線和表示MTM的虛線,其含義為模型的Top1準確率.
圖5(b)與圖5(a)不同的是模型在多類數(shù)據(jù)集上的評判標準中加入了Top5.曲線類型的含義是相同的,實線代表MC-MTM,虛線代表MTM.每一類曲線都分為無標記曲線和“×”形標記曲線,其含義分別為模型Top1準確率和Top5準確率.
圖5 MC-MTM和MTM在貓狗大戰(zhàn)和動物數(shù)據(jù)集上的實驗結(jié)果
整理所有數(shù)據(jù)得到表4所示的實驗結(jié)果數(shù)據(jù),在表4中粗體數(shù)據(jù)是本文提出的算法的最終結(jié)果數(shù)據(jù).相較于基準模型MTM,在貓狗大戰(zhàn)數(shù)據(jù)集上MC-MTM的Top1領先了1.73%,而在動物數(shù)據(jù)集上Top1和Top5分別領先了4.48%和1.75%.因此初定位網(wǎng)絡對目標對象進行初定位調(diào)整目標對象和背景之間的比例關系到合適大小,可以有效增強MTM網(wǎng)絡的性能.
表4 MC-MTM和MTM在貓狗大戰(zhàn)和動物數(shù)據(jù)集上的實驗數(shù)據(jù)
本次實驗中選取了網(wǎng)絡層數(shù)較多、結(jié)構(gòu)較為復雜的ResNet-50和相對層數(shù)較少、結(jié)構(gòu)簡單的VGGNet-16分別作為MC-MTM和MTM的骨干網(wǎng)絡,觀察MC-MTM在不同CNN網(wǎng)絡作為骨干網(wǎng)絡的條件下其性能相對于MTM是否增長.實驗設置貓狗大戰(zhàn)數(shù)據(jù)集,標記數(shù)據(jù)占比為8%,其他訓練超參數(shù)為表2中的數(shù)值,得到如圖6所示的實驗結(jié)果.
圖6(a)與圖6(b)分別代表了ResNet-50、VGGNet-16作為MTM和MC-MTM骨干網(wǎng)絡的實驗結(jié)果.兩幅圖中有兩種類型的曲線:實線和虛線,其含義分別為MC-MTM模型的Top1準確率和MTM模型的Top1準確率.最終得到表5中的實驗結(jié)果數(shù)據(jù).
圖6 多種CNN骨干網(wǎng)絡的MTM與MC-MTM在貓狗大戰(zhàn)數(shù)據(jù)集上的結(jié)果
表5中的黑體數(shù)據(jù)為MC-MTM的最終實驗結(jié)果數(shù)據(jù).搭載ResNet-50骨干網(wǎng)絡的MC-MTM的性能比MTM的提升了1.73%.而搭載了VGGNet-16骨干網(wǎng)絡的MC-MTM的性能比MTM的提升了1.38%.
表5 多種CNN骨干網(wǎng)絡的MTM與MC-MTM在貓狗大戰(zhàn)上的實驗數(shù)據(jù)
從簡單結(jié)構(gòu)的CNN網(wǎng)絡到復雜結(jié)構(gòu)的CNN網(wǎng)絡,MC-MTM相較于MTM都可以有效的增強模型性能.
本文提出的MC-MTM的核心結(jié)構(gòu)是結(jié)合了類激活映射的初定位網(wǎng)絡,起到了調(diào)整目標與背景之間的比例關系并且可以過濾一些冗余的背景信息和其他無關類別的干擾信息的作用.將目標主要特征過濾或者背景信息剩余太多的少量樣本則被認為是訓練的負樣本,可以增強模型的魯棒性.基于CNN的模型可以嵌入MC-MTM中,監(jiān)督學習方法在半監(jiān)督學習中得以應用.
本文設置了MC-MTM與MTM在Dogs vs.Cats和Animals-10數(shù)據(jù)集上的性能對比實驗,結(jié)果表明,MC-MTM在Top1和Top5上已經(jīng)明顯優(yōu)于MTM.并且在另一組實驗中置換了CNN骨干網(wǎng)絡進行實驗,MC-MTM在此實驗中的表現(xiàn)也都優(yōu)于MTM.因此可以證明本文提出的算法的有效性和可行性.在標簽數(shù)據(jù)占比實驗中得到MC-MTM算法對標簽數(shù)據(jù)的數(shù)量也有一定的要求的結(jié)論.并且在標簽數(shù)據(jù)占比已經(jīng)滿足了算法最低要求的基礎上,增加標簽數(shù)據(jù)對模型性能的增強是有利的.但面對大型數(shù)據(jù)集時,需要考慮標記數(shù)據(jù)花費的人力和時間成本的因素來確定訓練數(shù)據(jù)集中合適的標簽數(shù)據(jù)占比.