邵偉志,潘麗麗,雷前慧,黃詩(shī)祺,馬駿勇
(中南林業(yè)科技大學(xué) 計(jì)算機(jī)與信息工程學(xué)院 湖南 長(zhǎng)沙 410004)
在深度學(xué)習(xí)中,大量標(biāo)記數(shù)據(jù)對(duì)于神經(jīng)網(wǎng)絡(luò)的訓(xùn)練是至關(guān)重要的。但是對(duì)于許多深度學(xué)習(xí)的任務(wù)而言,獲取這些標(biāo)記數(shù)據(jù)是較為困難的,比如醫(yī)療任務(wù)中每一個(gè)標(biāo)記都需要從專(zhuān)家的結(jié)論中得出。此外,通過(guò)網(wǎng)絡(luò)獲取的信息很大一部分是較為私密的,怎樣標(biāo)記這些數(shù)據(jù)也是一個(gè)復(fù)雜的問(wèn)題。半監(jiān)督學(xué)習(xí)[1]通過(guò)讓模型從未標(biāo)記數(shù)據(jù)中獲取信息來(lái)減少對(duì)于標(biāo)記數(shù)據(jù)的依賴,對(duì)于圖像搜索、文本分類(lèi)、文檔檢索[2]等任務(wù),半監(jiān)督學(xué)習(xí)都能取得很好的結(jié)果。近年來(lái),半監(jiān)督學(xué)習(xí)方法聚焦于在損失函數(shù)中增加損失項(xiàng),這些損失項(xiàng)一般都是通過(guò)未標(biāo)記數(shù)據(jù)取得的,促使模型更好地利用未標(biāo)記數(shù)據(jù)中的信息來(lái)對(duì)數(shù)據(jù)進(jìn)行分類(lèi)。半監(jiān)督學(xué)習(xí)方法可以大致分為熵最小化[3]、一致性正則化[4]與傳統(tǒng)正則化[5]三類(lèi),但是已有方法往往很容易忽視數(shù)據(jù)互補(bǔ)信息與多階段模型共同作用的優(yōu)勢(shì)。在上述研究基礎(chǔ)上,本文提出一種新的半監(jiān)督算法Mean Mixup,重新考慮模型生成偽標(biāo)簽的方法,通過(guò)多種數(shù)據(jù)增強(qiáng)方法使得模型能夠?qū)W習(xí)數(shù)據(jù)的互補(bǔ)信息,并且讓不同階段的模型共同作用,最終讓模型產(chǎn)生低熵預(yù)測(cè)來(lái)獲取更準(zhǔn)確的偽標(biāo)簽。為更好地應(yīng)用一致性正則化,對(duì)標(biāo)記數(shù)據(jù)和未標(biāo)記數(shù)據(jù)進(jìn)行混洗之后,根據(jù)數(shù)據(jù)類(lèi)型傳入不同的一致性損失函數(shù),并使用比重系數(shù)調(diào)節(jié)來(lái)讓模型能更好地從一致性正則化中受益。在常用數(shù)據(jù)集SVHN和CIFAR10上的實(shí)驗(yàn)結(jié)果驗(yàn)證了新算法的有效性,其在分類(lèi)準(zhǔn)確率上優(yōu)于Pseudo-label、Π-model等半監(jiān)督學(xué)習(xí)算法。
在深度學(xué)習(xí)中,聚類(lèi)假設(shè)指出,模型的決策邊界最好不通過(guò)邊緣數(shù)據(jù)分布的高密度區(qū)域,也就是使得模型輸出分布的熵盡可能小,這樣會(huì)使模型獲得更好的泛化性[6]。體現(xiàn)在學(xué)習(xí)過(guò)程中的熵最小化是讓模型對(duì)目標(biāo)數(shù)據(jù)的分類(lèi)結(jié)果盡量自信,使得模型的決策邊界盡量遠(yuǎn)離邊緣數(shù)據(jù)點(diǎn),同時(shí)讓模型的擬合曲線更貼合數(shù)據(jù)的邊緣分布。圖1對(duì)比了雙月系統(tǒng)中原始決策邊界與熵最小化約束下的決策邊界。半監(jiān)督學(xué)習(xí)算法中經(jīng)常通過(guò)添加損失項(xiàng)來(lái)使得模型在未標(biāo)記數(shù)據(jù)的概率分布實(shí)現(xiàn)熵最小化。Pseudo-label算法[3]對(duì)未標(biāo)記數(shù)據(jù)進(jìn)行預(yù)測(cè),利用熵最小化獲得置信度高的預(yù)測(cè)分布作為偽標(biāo)簽,并將其用作標(biāo)準(zhǔn)交叉熵?fù)p失的訓(xùn)練目標(biāo)[2]。本文提出的Mean Mixup算法類(lèi)似于Pseudo-label算法,都對(duì)未標(biāo)記數(shù)據(jù)構(gòu)建偽標(biāo)簽,不同之處在于Mean Mixup算法對(duì)如何得到偽標(biāo)簽進(jìn)行了新的設(shè)計(jì)。
圖1 原始決策邊界和熵最小化約束下的決策邊界對(duì)比
一致性是指模型對(duì)受到擾動(dòng)的數(shù)據(jù)點(diǎn)應(yīng)輸出相同的分布預(yù)測(cè),半監(jiān)督學(xué)習(xí)算法的很多突破性進(jìn)展都是在一致性正則化基礎(chǔ)上取得的。Π-model算法[4]通過(guò)在隨機(jī)模型fθ(x)對(duì)同一樣本的預(yù)測(cè)之間施加約束來(lái)實(shí)現(xiàn)一致性正則化。VAT算法[7]直接對(duì)輸入x增加擾動(dòng),并且這種擾動(dòng)能使得預(yù)測(cè)產(chǎn)生最大偏移,在受擾動(dòng)樣本與未受擾動(dòng)樣本產(chǎn)生的輸出分布之間施加一致性約束。Mean teacher算法[8]通過(guò)構(gòu)建教師-學(xué)生框架來(lái)實(shí)現(xiàn)一致性約束,在這個(gè)框架中使用了兩個(gè)結(jié)構(gòu)一致的網(wǎng)絡(luò),且教師網(wǎng)絡(luò)的參數(shù)是學(xué)生網(wǎng)絡(luò)參數(shù)的指數(shù)移動(dòng)平均值,為更加直觀,在本文中分別稱為原型網(wǎng)絡(luò)與指數(shù)網(wǎng)絡(luò)。指數(shù)網(wǎng)絡(luò)輸入樣本是對(duì)原型網(wǎng)絡(luò)輸入樣本的加噪值,在兩個(gè)網(wǎng)絡(luò)的預(yù)測(cè)分布之間通過(guò)應(yīng)用KL散度或者交叉熵函數(shù)來(lái)施加一致性約束。一致性正則化對(duì)于深度學(xué)習(xí),尤其是半監(jiān)督學(xué)習(xí)而言很有幫助,使得模型能夠從標(biāo)記數(shù)據(jù)的標(biāo)簽信息之外得到更多的高維特征信息。
Lloss=Lx+λuLu+λcLcon。
(1)
由于確認(rèn)偏差[9]的原因,直接對(duì)未標(biāo)記數(shù)據(jù)生成偽標(biāo)簽很容易對(duì)錯(cuò)誤標(biāo)簽過(guò)度自信,從而不會(huì)繼續(xù)從未標(biāo)記數(shù)據(jù)中進(jìn)行學(xué)習(xí)。Mean Mixup算法使用在同一模型的多個(gè)變種的共同作用下生成偽標(biāo)簽的方法,使偽標(biāo)簽的生成能獲得多個(gè)角度的互補(bǔ)信息,對(duì)于未標(biāo)記數(shù)據(jù)的判斷更加可靠。為了得到未標(biāo)記數(shù)據(jù)的軟標(biāo)簽[10],使得未標(biāo)記數(shù)據(jù)可以隨著網(wǎng)絡(luò)學(xué)習(xí)不斷更新偽標(biāo)簽的生成結(jié)果,這種新的偽標(biāo)簽獲取方法保證了網(wǎng)絡(luò)能夠從不同的角度和時(shí)間段受益,并逐漸提升偽標(biāo)簽的準(zhǔn)確度。偽標(biāo)簽猜測(cè)流程如圖2所示。
圖2 偽標(biāo)簽猜測(cè)流程
(2)
為了使網(wǎng)絡(luò)可以獲得更多不同角度的信息[11],使用不同時(shí)間段的網(wǎng)絡(luò)對(duì)未標(biāo)記數(shù)據(jù)進(jìn)行預(yù)測(cè)。為了專(zhuān)注于獲取原樣本的特征信息,僅將原未標(biāo)記數(shù)據(jù)傳入其他時(shí)間段的網(wǎng)絡(luò),并進(jìn)行如下計(jì)算:
(3)
(4)
其中:T是超參數(shù),T→0時(shí)銳化函數(shù)的輸出趨近“one-hot”編碼。文獻(xiàn)[12]中指出,降低“溫度(T)”有利于模型產(chǎn)生低熵預(yù)測(cè)。在運(yùn)行過(guò)程中,算法對(duì)每一個(gè)批次的未標(biāo)記數(shù)據(jù)都執(zhí)行以上的方法計(jì)算偽標(biāo)簽,這種構(gòu)建偽標(biāo)簽的方法使得偽標(biāo)簽的準(zhǔn)確度隨著模型學(xué)習(xí)不斷提升。
一致性約束會(huì)使模型擁有更好的抗干擾能力,以往一致性約束通常是添加在網(wǎng)絡(luò)的預(yù)測(cè)分布之間,區(qū)別是對(duì)輸入樣本加噪或者是網(wǎng)絡(luò)參數(shù)的變化。但在很多應(yīng)用中,標(biāo)記數(shù)據(jù)與未標(biāo)記數(shù)據(jù)經(jīng)常出現(xiàn)分布不匹配,甚至某一類(lèi)的標(biāo)記樣本數(shù)極少,模型難以獲取足夠的信息。對(duì)數(shù)據(jù)使用Mixup進(jìn)行混洗來(lái)彌補(bǔ)兩類(lèi)數(shù)據(jù)之間的差異,使得模型學(xué)習(xí)的擬合曲線更符合數(shù)據(jù)分布,同時(shí)Mixup還實(shí)現(xiàn)了傳統(tǒng)正則化對(duì)于網(wǎng)絡(luò)的調(diào)節(jié)作用[13]。Mixup[12]中對(duì)于兩個(gè)帶標(biāo)簽的樣本(x1,P1)和(x2,P2),其混合后的目標(biāo)(x′,P′)為
(5)
圖3 一致性約束的實(shí)現(xiàn)方式
未標(biāo)記數(shù)據(jù)的一致性損失Lc1和標(biāo)記數(shù)據(jù)的一致性損失Lc2可以分別表示為
(6)
(7)
(8)
(9)
損失函數(shù)中一致性損失項(xiàng)Lcon為L(zhǎng)c1與Lc2之和,即Lcon=Lc1+Lc2。損失函數(shù)中的未標(biāo)記數(shù)據(jù)分類(lèi)損失和一致性損失通過(guò)L2損失函數(shù)計(jì)算。L2損失函數(shù)與交叉熵不同,它是有界的,而且對(duì)完全錯(cuò)誤的判斷不太敏感,經(jīng)常用作半監(jiān)督學(xué)習(xí)中對(duì)未標(biāo)記數(shù)據(jù)預(yù)測(cè)的損失以及預(yù)測(cè)結(jié)果不確定性的度量[14]。
本文將提出的Mean Mixup算法在TensorFlow2.0平臺(tái)上實(shí)現(xiàn),并與Mean teacher[8]、VAT[7]、Π-model[4]、MixMatch[14]以及Pseudo-label[3]算法進(jìn)行了比較。所有算法選擇的網(wǎng)絡(luò)均為“Wide ResNet-28-2”結(jié)構(gòu),但并沒(méi)有使用學(xué)習(xí)率周期表而只使用了學(xué)習(xí)率衰減,選取運(yùn)行100輪后得到的結(jié)果進(jìn)行對(duì)比。Pseudo-label與MixMatch算法的對(duì)比結(jié)果是在TensorFlow2.0平臺(tái)上進(jìn)行復(fù)現(xiàn)得到的,其他算法的實(shí)驗(yàn)結(jié)果來(lái)自文獻(xiàn)[15],選取的對(duì)比指標(biāo)為錯(cuò)誤率。MixMatch算法根據(jù)文獻(xiàn)[14]選擇超參數(shù)與學(xué)習(xí)率,并選取運(yùn)行100輪后的結(jié)果作對(duì)比。
CIFAR10是一個(gè)深度學(xué)習(xí)常用數(shù)據(jù)集,包含50 000張訓(xùn)練樣本以及10 000張測(cè)試樣本,每個(gè)樣本都是32*32的RGB圖片,并且分屬于10個(gè)類(lèi)別,類(lèi)別各自獨(dú)立,不會(huì)產(chǎn)生重疊。遵循常規(guī)半監(jiān)督學(xué)習(xí)的設(shè)置,實(shí)驗(yàn)中使用了4 000個(gè)帶標(biāo)記的樣本。設(shè)定Mean Mixup算法學(xué)習(xí)率為0.002,對(duì)輸入圖片只進(jìn)行了歸一化處理。結(jié)果表明,Pseudo-label、Mean teacher、VAT、Π-model、MixMatch算法的錯(cuò)誤率分別為15.54%、15.87%、13.86%、16.37%、7.24%,而Mean Mixup算法的錯(cuò)誤率僅為6.37%。從實(shí)驗(yàn)結(jié)果可知,在CIFAR10數(shù)據(jù)集中VAT算法比同樣使用一致性正則化作為主要指導(dǎo)思想的Mean teacher算法表現(xiàn)要好,這可能是由于噪聲的方向選擇能夠使得模型更好地學(xué)習(xí)。MixMatch和Mean Mixup算法的錯(cuò)誤率比單純一致性正則化的Mean teacher、VAT以及Π-model算法低,這證明了在半監(jiān)督學(xué)習(xí)中使用熵最小化構(gòu)建偽標(biāo)簽是有效的。為了對(duì)Mean Mixup算法進(jìn)行更詳細(xì)的實(shí)驗(yàn)論證,分別在CIFAR10數(shù)據(jù)集中選擇了250、500、1 000、2 000個(gè)標(biāo)簽進(jìn)行100輪的實(shí)驗(yàn),算法的錯(cuò)誤率結(jié)果分別為18.70%、14.86%、11.42%、7.64%。更少標(biāo)簽數(shù)據(jù)下的實(shí)驗(yàn)結(jié)果表明,在相同網(wǎng)絡(luò)架構(gòu)中,Mean Mixup算法在僅使用2 000個(gè)標(biāo)記樣本的情況下接近甚至超過(guò)經(jīng)典半監(jiān)督算法的表現(xiàn),證明了Mean Mixup算法對(duì)于標(biāo)簽樣本的利用率更高。
SVHN數(shù)據(jù)集來(lái)源于谷歌街景門(mén)牌號(hào)碼,經(jīng)過(guò)裁剪成為32*32的RGB圖片,包含73 257個(gè)訓(xùn)練樣本和26 032個(gè)測(cè)試樣本,被劃分為10個(gè)類(lèi)別,設(shè)定學(xué)習(xí)率為0.002。將Mean Mixup算法與經(jīng)典半監(jiān)督算法在使用4 000個(gè)標(biāo)簽樣本運(yùn)行100輪的實(shí)驗(yàn)結(jié)果進(jìn)行了對(duì)比。結(jié)果表明,Pseudo-label、Mean teacher、VAT、Π-model、MixMatch算法的錯(cuò)誤率分別為5.37%、5.65%、6.31%、7.19%、3.89%,而Mean Mixup算法的錯(cuò)誤率僅為2.87%。在相同的標(biāo)簽數(shù)據(jù)下,Mean Mixup算法的分類(lèi)錯(cuò)誤率較其他半監(jiān)督算法更低,并且相較于使用單一正則化的Pseudo-label等方法優(yōu)勢(shì)較為明顯。同時(shí),在CIFAR10數(shù)據(jù)集中比Mean teacher算法表現(xiàn)更好的VAT算法在SVHN數(shù)據(jù)集中并沒(méi)有體現(xiàn)出優(yōu)勢(shì),表明了在難以獲得足夠多的標(biāo)簽信息的半監(jiān)督學(xué)習(xí)中,只使用一致性正則化或熵最小化很難獲得出眾的結(jié)果。為驗(yàn)證Mean Mixup算法在更少標(biāo)簽數(shù)據(jù)下的表現(xiàn),進(jìn)行了四組少標(biāo)簽(250、500、1 000、2 000個(gè)標(biāo)簽)數(shù)據(jù)實(shí)驗(yàn),算法的錯(cuò)誤率結(jié)果分別為9.13%、8.07%、6.58%、5.30%。Mean Mixup算法在只有2 000個(gè)標(biāo)簽數(shù)據(jù)的情況下依然取得了錯(cuò)誤率為5.30%的成績(jī),這與4 000個(gè)標(biāo)簽數(shù)據(jù)下Pseudo-label算法的結(jié)果相近,且高于VAT和Π-model算法,再一次驗(yàn)證了Mean Mixup算法對(duì)于標(biāo)簽樣本的利用率更高。
為了驗(yàn)證所得偽標(biāo)簽的準(zhǔn)確度,將使用驗(yàn)證集猜測(cè)得到的標(biāo)簽與其自帶的標(biāo)簽進(jìn)行對(duì)比,每隔20輪記錄下準(zhǔn)確率,偽標(biāo)簽準(zhǔn)確率結(jié)果如圖4所示。可以看出,通過(guò)集成數(shù)據(jù)互補(bǔ)信息進(jìn)而獲得低熵偽標(biāo)簽的方法是有效的,且在250個(gè)標(biāo)簽數(shù)據(jù)的情況下所得的偽標(biāo)簽準(zhǔn)確率也達(dá)到了85.78%,與4 000個(gè)標(biāo)簽數(shù)據(jù)下的準(zhǔn)確率差距不大,這表明Mean Mixup算法在標(biāo)簽數(shù)據(jù)稀少的情況下生成偽標(biāo)簽的準(zhǔn)確度依然較高。
圖4 偽標(biāo)簽準(zhǔn)確率結(jié)果
在Mean Mixup算法中有四個(gè)較為重要的超參數(shù),分別為未標(biāo)記數(shù)據(jù)增強(qiáng)次數(shù)K、Mixup中取樣區(qū)間λ以及未標(biāo)記數(shù)據(jù)分類(lèi)損失與一致性損失各自的比重系數(shù)λu和λc。為了更直觀地展示超參數(shù)的選擇,同時(shí)避免超參數(shù)細(xì)微變化所帶來(lái)的不公平的性能比較,僅選擇了四組超參數(shù)在數(shù)據(jù)集中進(jìn)行實(shí)驗(yàn),其中依照MixMatch算法中的設(shè)置使得α=0.75,遵循 Mean teacher算法使得λc=1。對(duì)于每組超參數(shù),均使用4 000個(gè)標(biāo)簽,所應(yīng)用的數(shù)據(jù)預(yù)處理方式與優(yōu)化器都是一致的,最終選取其運(yùn)行100輪后的錯(cuò)誤率來(lái)進(jìn)行對(duì)比。結(jié)果表明:K=1,λu=75時(shí)錯(cuò)誤率為7.37%;K=1,λu=150時(shí)錯(cuò)誤率為7.86%;K=2,λu=75時(shí)錯(cuò)誤率為6.37%;K=2,λu=150時(shí)錯(cuò)誤率為6.65%。從四組不同超參數(shù)對(duì)比實(shí)驗(yàn)結(jié)果中可以看出,未標(biāo)記數(shù)據(jù)增強(qiáng)次數(shù)K對(duì)于結(jié)果的影響較大,這是由于在生成偽標(biāo)簽的過(guò)程中,多個(gè)不同增強(qiáng)實(shí)例的反饋能增強(qiáng)偽標(biāo)簽的準(zhǔn)確度。而一致性損失值在實(shí)驗(yàn)過(guò)程中一直較小,需要使用較大的比重系數(shù)才能使網(wǎng)絡(luò)從一致性損失中進(jìn)行學(xué)習(xí),所以直接從75增大到150對(duì)于實(shí)驗(yàn)結(jié)果的影響也不明顯。因此,對(duì)比實(shí)驗(yàn)中超參數(shù)的選擇為K=2,α=0.75,λu=75,λc=1。
本文針對(duì)以往半監(jiān)督算法往往忽略數(shù)據(jù)互補(bǔ)信息的不足,提出了一種新的半監(jiān)督算法Mean Mixup。該方法能夠有效利用少量標(biāo)簽帶來(lái)的信息,并推廣到未標(biāo)記數(shù)據(jù)上。Mean Mixup算法基于熵最小化與一致性正則化的思想,設(shè)計(jì)了通過(guò)多階段模型共同作用,集成多角度信息從而生成低熵偽標(biāo)簽的方法,并利用一致性正則化優(yōu)化了模型的分類(lèi)性能。在經(jīng)典數(shù)據(jù)集CIFAR10和SVHN上與現(xiàn)有的半監(jiān)督算法進(jìn)行了比較,實(shí)驗(yàn)結(jié)果表明,在相同標(biāo)簽數(shù)的情況下,Mean Mixup算法的分類(lèi)準(zhǔn)確度較之前的半監(jiān)督方法表現(xiàn)更好。即使在更少標(biāo)簽數(shù)據(jù)的情況下,Mean Mixup算法獲得的準(zhǔn)確度也超過(guò)了之前使用單一正則化的半監(jiān)督方法。本文還驗(yàn)證了生成偽標(biāo)簽的準(zhǔn)確度,發(fā)現(xiàn)即使在標(biāo)簽數(shù)據(jù)稀少的情況下,生成偽標(biāo)簽的準(zhǔn)確度依然較高,表明Mean Mixup在解決半監(jiān)督學(xué)習(xí)問(wèn)題上是有效的,且集成數(shù)據(jù)信息生成偽標(biāo)簽的方法是正確的。