吳豪杰 ,王妍潔 ,蔡文炳 ,王 飛 ,劉 洋 ,蒲 鵬 ,林紹輝
(1.中國(guó)電子科技集團(tuán)公司第二十七研究所,鄭州 450047;2.北京跟蹤與通信技術(shù)研究所,北京 100094;3.中國(guó)人民解放軍63726 部隊(duì),銀川 750004;4.華東師范大學(xué) 計(jì)算機(jī)科學(xué)與技術(shù)學(xué)院,上海 200062;5.華東師范大學(xué) 數(shù)據(jù)科學(xué)與工程學(xué)院,上海 200062)
近年來,隨著深度學(xué)習(xí)與圖形處理器(Graphics Processing Unit,GPU)硬件的不斷發(fā)展,卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks,CNNs)已經(jīng)在諸多人工智能領(lǐng)域取得了顯著的成效,如區(qū)塊鏈[1]、圖像分類[2]、目標(biāo)檢測(cè)[3]等.得益于其大規(guī)模的數(shù)據(jù)量與強(qiáng)大的特征提取能力,CNNs 在某些任務(wù)上甚至已經(jīng)超過了人類識(shí)別的準(zhǔn)確率[4].同時(shí),GPU 硬件的高速發(fā)展大大提高了網(wǎng)絡(luò)模型的計(jì)算效率.
隨著網(wǎng)絡(luò)模型性能的提升,其計(jì)算開銷與存儲(chǔ)量也在不斷增加.如AlexNet[2]模型,其具有0.61 億網(wǎng)絡(luò)參數(shù)和7.29 億次浮點(diǎn)計(jì)算量(Floating-point Operations per Second,FLOPs),占用約240 MB 的存儲(chǔ)空間.對(duì)于被廣為使用的152 層殘差網(wǎng)絡(luò)(Residual Network-152,ResNet-152)[4]具有0.57 億網(wǎng)絡(luò)參數(shù)和113 億次浮點(diǎn)計(jì)算量,占用約230 MB 的存儲(chǔ)空間.龐大的網(wǎng)絡(luò)參數(shù)意味著更大的內(nèi)存占用,而巨大的浮點(diǎn)計(jì)算量意味著高昂的訓(xùn)練代價(jià)與較小的推理速度.這使得如此高存儲(chǔ)、高功耗模型無法直接在資源有限的應(yīng)用場(chǎng)景下應(yīng)用,如手機(jī)、無人機(jī)、機(jī)器人等邊緣嵌入式設(shè)備.因此,在保持模型識(shí)別準(zhǔn)確率的前提下,對(duì)于網(wǎng)絡(luò)模型進(jìn)行壓縮與加速,以適應(yīng)邊緣設(shè)備的實(shí)際要求,成為了當(dāng)前計(jì)算機(jī)視覺領(lǐng)域火熱的研究課題.與此同時(shí),也有研究表明[5],在巨大的網(wǎng)絡(luò)參數(shù)內(nèi)部,并不是所有的結(jié)構(gòu)和參數(shù)對(duì)于網(wǎng)絡(luò)的識(shí)別預(yù)測(cè)能力都起到?jīng)Q定性作用,這使得模型壓縮技術(shù),即移除冗余性參數(shù)和計(jì)算量成為了一種有效的解決方案.
當(dāng)前主流的模型壓縮方法可以分為5 種,分別為參數(shù)剪枝、參數(shù)量化、低秩分解、輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì)和知識(shí)蒸餾(Knowledge Distillation,KD).知識(shí)蒸餾方法可以直接設(shè)定壓縮后模型的結(jié)構(gòu)、計(jì)算量和參數(shù)量,以及不引入額外的計(jì)算算子,這使得知識(shí)蒸餾技術(shù)得到了廣泛關(guān)注.因此,本文也著重研究基于知識(shí)蒸餾的模型壓縮方法.知識(shí)蒸餾方法將較大和較小的網(wǎng)絡(luò)分別定義為教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò) (也稱之為壓縮后網(wǎng)絡(luò)).其主要思想在于,通過最小化該兩個(gè)網(wǎng)絡(luò)輸出分布差異,來實(shí)現(xiàn)網(wǎng)絡(luò)間的知識(shí)遷移,使得學(xué)生網(wǎng)絡(luò)盡可能地獲得教師網(wǎng)絡(luò)的知識(shí),提高學(xué)生網(wǎng)絡(luò)的準(zhǔn)確率.從而,學(xué)生網(wǎng)絡(luò)可以在維持其參數(shù)量不變的情況下提升性能,盡可能逼近甚至有可能超越教師網(wǎng)絡(luò)的性能.傳統(tǒng)的知識(shí)蒸餾方法是將網(wǎng)絡(luò)的輸出分布作為知識(shí)在網(wǎng)絡(luò)間進(jìn)行遷移,隨著該研究領(lǐng)域的進(jìn)一步發(fā)展,研究發(fā)現(xiàn)[6],利用其他一些具有代表性的表征信息或知識(shí)在網(wǎng)絡(luò)間進(jìn)行遷移或蒸餾,可以獲得比傳統(tǒng)知識(shí)蒸餾方法更好的效果.知識(shí)蒸餾方法大致又可以分為: ①基于網(wǎng)絡(luò)輸出層的知識(shí)蒸餾方法;② 基于網(wǎng)絡(luò)中間層的知識(shí)蒸餾方法;③基于樣本關(guān)系之間的知識(shí)蒸餾方法.
本文提出了一種新的基于隱層相關(guān)聯(lián)算子的知識(shí)蒸餾 (Correlation Operation Based Knowledge Distillation,CorrKD) 方法,通過計(jì)算教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)各自隱含層之間的關(guān)聯(lián)性,挖掘出更有效的知識(shí)表征,從而將教師的知識(shí)表征遷移到學(xué)生的知識(shí)表征中,提高學(xué)生網(wǎng)絡(luò)的判別性.該方法的核心是利用了被廣泛應(yīng)用于光流[7-8]、圖像匹配[9]等領(lǐng)域內(nèi)的相關(guān)聯(lián)算子,用于提取網(wǎng)絡(luò)中間層的知識(shí)表征.相關(guān)聯(lián)算子的特性在于,可以很好地表征兩個(gè)特征之間的匹配程度,并反映其特征的變化過程.首先,本文對(duì)于網(wǎng)絡(luò)中每個(gè)階段的輸入特征與輸出特征,利用相關(guān)聯(lián)算子進(jìn)行建模與知識(shí)提取,有效獲得了圖像特征的學(xué)習(xí)變化信息.然后,將教師網(wǎng)絡(luò)每階段通過相關(guān)聯(lián)算子提取出的表征信息作為知識(shí),遷移到學(xué)生網(wǎng)絡(luò)中,提升學(xué)生網(wǎng)絡(luò)判別性和學(xué)習(xí)有效性.
在CIFAR-10 和CIFAR-100 分類數(shù)據(jù)集評(píng)測(cè)結(jié)果中,相比其他中間層知識(shí)蒸餾方法,本文所提出的方法取得了較好的效果.同時(shí),本文所提出的方法在減小網(wǎng)絡(luò)的計(jì)算量和參數(shù)量的同時(shí),能夠有效逼近原始網(wǎng)絡(luò)的準(zhǔn)確率.
除本文將詳細(xì)介紹的知識(shí)蒸餾方法外,其他主流的模型壓縮方法有: ①參數(shù)剪枝[10-11],該方法的主要思想在于,通過對(duì)已訓(xùn)練好的深度神經(jīng)網(wǎng)絡(luò)模型移除冗余、信息量較少的權(quán)值,減少網(wǎng)絡(luò)模型的參數(shù),進(jìn)而增大模型的計(jì)算速度和減小模型所占用的存儲(chǔ)空間,實(shí)現(xiàn)模型壓縮;② 參數(shù)量化[12-14],該方法的主要思想是一種將多個(gè)參數(shù)實(shí)現(xiàn)共享的直接表示形式,其核心思想在于,利用較低的位來代替原始32 位的浮點(diǎn)型參數(shù),從而縮減網(wǎng)絡(luò)存儲(chǔ)和浮點(diǎn)計(jì)算次數(shù);③低秩分解[15-16],該方法的核心思想在于,利用矩陣或張量的分解技術(shù)對(duì)網(wǎng)絡(luò)模型中的原始卷積核進(jìn)行分解.一般來說,卷積計(jì)算是網(wǎng)絡(luò)中復(fù)雜度最高且最為普遍的計(jì)算操作,通過對(duì)張量進(jìn)行分解從而減小模型內(nèi)部冗余性,實(shí)現(xiàn)模型壓縮;④ 輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì),輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì)的方法主要是改變了卷積神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)特征,提出了一些新穎的輕量計(jì)算模塊或操作,從而精簡(jiǎn)網(wǎng)絡(luò)結(jié)構(gòu),增大處理速度.如基于深度可分離卷積的MobileNet[17],利用神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索得到的EfficientNet[18]等.
知識(shí)蒸餾方法[19]指利用教師網(wǎng)絡(luò)中的知識(shí)表征為學(xué)生網(wǎng)絡(luò)提供指導(dǎo),以提高學(xué)生網(wǎng)絡(luò)的性能.傳統(tǒng)的知識(shí)蒸餾方法通過最小化教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)類別輸出分布的KL (Kullback-Leibler)散度來實(shí)現(xiàn)蒸餾.除了在輸出層外,網(wǎng)絡(luò)中間層的特征信息也被應(yīng)用到知識(shí)蒸餾方法中.
中間層特征知識(shí)的構(gòu)造.Romero 等[20]提出的FitNet 是較早利用中間特征信息進(jìn)行知識(shí)蒸餾的方法,其目標(biāo)是使經(jīng)過奇異值分解的學(xué)生網(wǎng)絡(luò)盡可能學(xué)習(xí)教師網(wǎng)絡(luò)中間層的特征信息.隨后,Zagoruyko 等[21]提出在網(wǎng)絡(luò)中間層引入注意力機(jī)制,將每層的注意力特征作為可學(xué)習(xí)的知識(shí)遷移到學(xué)生網(wǎng)絡(luò)中.近年來,隨著自注意力模型被廣泛運(yùn)用到變形器[22]中,進(jìn)而獲得人工智能領(lǐng)域各項(xiàng)任務(wù)的性能突破,相關(guān)知識(shí)蒸餾方法[23-24]通過對(duì)齊教師與學(xué)生的自注意力矩陣實(shí)現(xiàn)知識(shí)遷移.Yim 等[25]提出了FSP (Flow of Solution Procedure)方法,將網(wǎng)絡(luò)中每層之間的數(shù)據(jù)流動(dòng)關(guān)系作為知識(shí),由教師網(wǎng)絡(luò)遷移到學(xué)生網(wǎng)絡(luò)中.除此之外,樣本之間的關(guān)系特征也被發(fā)現(xiàn)可以凝煉出更好的知識(shí)表示.例如,Park 等[26]提出RKD (Relational Knowledge Distillation)知識(shí)蒸餾框架,對(duì)于不同樣本網(wǎng)絡(luò)輸出的結(jié)構(gòu)關(guān)系進(jìn)行建模,將關(guān)系特征進(jìn)行知識(shí)遷移.此外,Liu 等[27]通過將教師網(wǎng)絡(luò)特征空間映射到由頂點(diǎn)與邊構(gòu)成的圖表示空間中,然后對(duì)齊教師與學(xué)生網(wǎng)絡(luò)的頂點(diǎn)以及它們邊的對(duì)應(yīng)信息實(shí)現(xiàn)知識(shí)蒸餾.Tung 等[28]利用網(wǎng)絡(luò)中間層每個(gè)樣本之間的相似度信息進(jìn)行知識(shí)遷移.Kim 等[29]提出在教師網(wǎng)絡(luò)的最后一層特征中提取便于學(xué)生網(wǎng)絡(luò)理解的轉(zhuǎn)移因子,將知識(shí)傳遞給學(xué)生網(wǎng)絡(luò).對(duì)于教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中間層特征不一致的情況,Heo 等[30]提出了使用 1×1 卷積進(jìn)行維度對(duì)齊,并構(gòu)建教師網(wǎng)絡(luò)激活邊界作為中間層知識(shí)遷移到學(xué)生網(wǎng)絡(luò)中.不僅如此,特征圖的雅可比梯度信息[31]也可以作為中間層特征知識(shí)表示.近年來,出現(xiàn)了一些在輸出層特征進(jìn)行對(duì)比學(xué)習(xí)[32]或基于自監(jiān)督[33]的知識(shí)蒸餾方法,分別用于挖掘教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)對(duì)于不同樣本之間的關(guān)系,從而將教師網(wǎng)絡(luò)的關(guān)系知識(shí)遷移到學(xué)生網(wǎng)絡(luò)中.不同于以上知識(shí)蒸餾方法,本文所提出的基于相關(guān)聯(lián)系數(shù)的知識(shí)蒸餾方法作用于每階段中間層特征信息,從而獲得每階段中間特征變化信息,能更好構(gòu)建知識(shí)表征,提高學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)性能.
使用優(yōu)化訓(xùn)練策略進(jìn)行中間層知識(shí)蒸餾.近年來,大量生成對(duì)抗思想被應(yīng)用到中間層知識(shí)蒸餾中,提高知識(shí)蒸餾性能.例如,Su 等[34]引入了任務(wù)驅(qū)動(dòng)的注意力機(jī)制,將教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)各自高層信息嵌入低層中,實(shí)現(xiàn)中間層信息的遷移,同時(shí)加入判別器用于增強(qiáng)學(xué)生網(wǎng)絡(luò)最后輸出特征的魯棒性.類似地,Shen 等[35]提出了基于對(duì)抗學(xué)習(xí)的多教師網(wǎng)絡(luò)集成蒸餾框架,利用自適應(yīng)池化操作對(duì)齊一個(gè)學(xué)生與多個(gè)教師集成網(wǎng)絡(luò)的中間層輸出維度,同時(shí)利用生成對(duì)抗策略對(duì)池化的中間層特征進(jìn)行對(duì)抗訓(xùn)練,提高了知識(shí)蒸餾性能.Chung 等[36]提出了基于中間層特征圖的在線對(duì)抗蒸餾框架,設(shè)計(jì)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的判別器,用于共同學(xué)習(xí)和對(duì)齊這兩個(gè)網(wǎng)絡(luò)在訓(xùn)練過程中的特征圖分布的變化情況.Jin 等[37]提出了一種路線限制優(yōu)化策略,預(yù)先設(shè)定好教師網(wǎng)絡(luò)訓(xùn)練的中間模型狀態(tài),并通過逐步對(duì)齊學(xué)生網(wǎng)絡(luò)與其中間層特征分布,使得學(xué)生網(wǎng)絡(luò)獲得更好的局部最優(yōu)解.
知識(shí)蒸餾方法[19]認(rèn)為在數(shù)據(jù)的網(wǎng)絡(luò)輸出中,每一個(gè)數(shù)據(jù)的預(yù)測(cè)概率結(jié)果都可以看作是一個(gè)分布,不僅關(guān)注于置信度最高的類別所對(duì)應(yīng)的結(jié)果,而且對(duì)于預(yù)測(cè)錯(cuò)誤結(jié)果的置信度概率也具備一定的網(wǎng)絡(luò)知識(shí).在傳統(tǒng)分類任務(wù)所使用的交叉熵?fù)p失函數(shù)中,只會(huì)關(guān)注對(duì)應(yīng)于正確類別的概率值,對(duì)于其他類別所對(duì)應(yīng)的概率是直接丟棄,沒有利用的,Hinton 等[19]將其稱作是暗知識(shí).在知識(shí)蒸餾的過程中,學(xué)生網(wǎng)絡(luò)所學(xué)習(xí)到的,不僅是預(yù)測(cè)正確的類別所對(duì)應(yīng)的概率值結(jié)果,而且包括教師網(wǎng)絡(luò)所學(xué)習(xí)到的暗知識(shí).
在具體的實(shí)現(xiàn)過程中,將教師網(wǎng)絡(luò)記為ft,學(xué)生網(wǎng)絡(luò)記為fs,將輸入記作x,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的模型輸出結(jié)果分別記為zt和zs,且zt=ft(x),zs=fs(x),zt,zs∈Rd,d為總類別數(shù).對(duì)于網(wǎng)絡(luò)得到的輸出分布,利用 Softmax對(duì)此進(jìn)行歸一化,得到概率分布.同時(shí),還引入了溫度分布參數(shù)τ用來平滑該層的輸出分布,以強(qiáng)化網(wǎng)絡(luò)輸出的概率分布中所學(xué)習(xí)到的知識(shí),通過溫度平滑后的網(wǎng)絡(luò)輸出被稱為軟目標(biāo).對(duì)此,以教師網(wǎng)絡(luò)為例,對(duì)于第i個(gè)輸入樣本xi,其軟目標(biāo)用公式表示為
式(2)中:n表示樣本總個(gè)數(shù),KL(ps||pt) 定義為學(xué)生網(wǎng)絡(luò)輸出分布與教師網(wǎng)絡(luò)輸出分布之間差異,具體公式表示為
所以,在學(xué)生網(wǎng)絡(luò)訓(xùn)練的過程中,教師網(wǎng)絡(luò)的軟目標(biāo)與真實(shí)標(biāo)簽共同起到監(jiān)督作用.傳統(tǒng)知識(shí)蒸餾損失函數(shù)為
式(3)中:LCE為傳統(tǒng)的學(xué)生網(wǎng)絡(luò)輸出與真實(shí)標(biāo)簽的交叉熵?fù)p失函數(shù);α為平衡因子,用于權(quán)衡LCE和LKL的重要性比例.
相關(guān)聯(lián)算子[7]被廣泛應(yīng)用到光流、圖像匹配、目標(biāo)跟蹤領(lǐng)域中,用于描述兩張圖像或兩個(gè)特征之間的匹配程度(圖1).對(duì)于三維的圖像特征張量A和B,其尺寸為C×H ×W,C、H和W分別表示其特征圖的通道數(shù)、高度與寬度.特征張量A中給定位置 (i,j)的特征為PA(i,j)∈RC,需要計(jì)算其與特征張量B中所對(duì)應(yīng)位置圖像塊的特征相似度,這里所對(duì)應(yīng)的圖像塊以 (i,j)為中心,大小為k×k,將該區(qū)域內(nèi)的像素位置記為 (i′,j′) ,所對(duì)應(yīng)的特征為PB(i′,j′) ,與PA(i,j)類似,該像素特征均為C維向量.因此,可以通過計(jì)算內(nèi)積的方式得到對(duì)應(yīng)像素特征之間的相似度,由此得到相關(guān)聯(lián)算子φ,其計(jì)算公式為
圖1 相關(guān)聯(lián)算子示意圖Fig.1 Illustration of correlation operation
式(4)中:⊙表示向量?jī)?nèi)積,為歸一化系數(shù).由此,可以得到特征張量A和B之間的相關(guān)聯(lián)算子,可以將其記為φ(A,B)∈Rk2×H×W.所以,對(duì)于給定的兩個(gè)三維圖像特征張量,可以通過計(jì)算像素特征與圖像塊中每個(gè)像素之間的相似度,得到尺寸為k2×H ×W的相關(guān)聯(lián)算子,用于反映特征之間的相似程度或匹配程度.
借助相關(guān)聯(lián)算子,可以計(jì)算網(wǎng)絡(luò)模型隱層中尺度相同的兩個(gè)特征張量之間的特征,用以反映特征的匹配相似程度,并利用其進(jìn)行知識(shí)遷移 (圖2).圖2 中的KL 損失LKL和LCor損失分別被定義于式(2)和式(5)中,xi和分別為第i個(gè)輸入樣本和該樣本增強(qiáng)變化后的表示.
圖2 基于隱層相關(guān)聯(lián)算子蒸餾方法的整體框架Fig.2 Illustration of intermediate CorrKD framework
通常,網(wǎng)絡(luò)模型會(huì)根據(jù)其特征圖空間尺寸大小的不同而劃分成不同的階段,換句話說,在相同的網(wǎng)絡(luò)階段內(nèi),其中間特征的維度尺寸都是相同的.因此,可以將每個(gè)階段的第一層特征與最后一層輸出特征作為相關(guān)聯(lián)算子中的特征張量A和B.該相關(guān)聯(lián)算子的計(jì)算可以很好地反映出模型每個(gè)階段對(duì)于數(shù)據(jù)的處理變化過程,成為非常有效的知識(shí)表征.因此,可以將相關(guān)聯(lián)算子計(jì)算結(jié)果用作知識(shí)蒸餾的表征信息,由教師網(wǎng)絡(luò)對(duì)學(xué)生網(wǎng)絡(luò)進(jìn)行指導(dǎo).假設(shè)網(wǎng)絡(luò)有N個(gè)階段,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的第i個(gè)階段的第一層輸入特征分別記為,最后一層的輸出特征分別記為Fit2和Fis2,其知識(shí)遷移的過程可以利用LCor損失進(jìn)行約束,對(duì)此,基于隱層相關(guān)聯(lián)算子的知識(shí)遷移損失函數(shù)可以表示為
式(5)中:λi,i=1,2,···,N表示第i階段的權(quán)重因子,||·||2為L(zhǎng)2范數(shù).為了更好形成多樣的知識(shí)表征,在本文中引入數(shù)據(jù)增強(qiáng)和變化[4](如旋轉(zhuǎn)、翻轉(zhuǎn)、顏色變化等),可以更有效地將隱含層的相關(guān)聯(lián)算子的知識(shí)遷移到學(xué)生網(wǎng)絡(luò)中,從而產(chǎn)生更好的效果.通過結(jié)合了教師網(wǎng)絡(luò)中傳統(tǒng)知識(shí)蒸餾損失函數(shù) (式 (3))和隱層相關(guān)聯(lián)算子的知識(shí)遷移損失函數(shù) (式 (5)),可以得到該知識(shí)蒸餾方法完整的訓(xùn)練損失函數(shù)公式為
式(6)中:β為超參數(shù),用于控制3 個(gè)損失 (LCE、LKL和LCor) 的平衡性.在訓(xùn)練過程中,本文直接使用梯度下降法優(yōu)化式 (6),選擇學(xué)生網(wǎng)絡(luò)進(jìn)行測(cè)試,并計(jì)算出學(xué)生網(wǎng)絡(luò)的準(zhǔn)確率作為該方法的評(píng)測(cè)效果.
本文在兩個(gè)經(jīng)典的分類公開數(shù)據(jù)集CIFAR-10 與CIFAR-100 上進(jìn)行了實(shí)驗(yàn),均包含6 萬張長(zhǎng)寬尺寸均為32 的圖像,其中5 萬張用于訓(xùn)練,剩下的1 萬張用于測(cè)試,他們的分類類別數(shù)分別為10和100.
本文所提出的方法使用Pytorch 在單張GPU 上進(jìn)行實(shí)現(xiàn),對(duì)于兩種數(shù)據(jù)集均采用隨機(jī)梯度下降方法進(jìn)行優(yōu)化.在訓(xùn)練中,圖像批量大小設(shè)置為64,學(xué)習(xí)率設(shè)置為0.05,動(dòng)量設(shè)置為0.9,權(quán)重衰減系數(shù)為0.000 5.對(duì)于教師網(wǎng)絡(luò),利用標(biāo)準(zhǔn)交叉熵?fù)p失函數(shù)進(jìn)行訓(xùn)練,訓(xùn)練迭代次數(shù)為240,其學(xué)習(xí)率分別在第150、180、210 次迭代時(shí),分別縮小為原來的1/10,訓(xùn)練完成后將教師網(wǎng)絡(luò)進(jìn)行保存,存儲(chǔ)于本地磁盤中.
對(duì)于學(xué)生網(wǎng)絡(luò),需要先讀取教師網(wǎng)絡(luò)的模型參數(shù),利用所提出的損失函數(shù)式 (6) 進(jìn)行訓(xùn)練,模型訓(xùn)練優(yōu)化器與學(xué)習(xí)率設(shè)置均與教師網(wǎng)絡(luò)一致,訓(xùn)練迭代次數(shù)設(shè)為300,其學(xué)習(xí)率分別在第180,220,260 次迭代時(shí),分別縮小為原來的1/10.
在相關(guān)聯(lián)算子的計(jì)算過程中,需要引入數(shù)據(jù)增強(qiáng),首先,對(duì)于圖像進(jìn)行隨機(jī)旋轉(zhuǎn)與翻轉(zhuǎn).其次,在圖像色彩上從灰度轉(zhuǎn)化、色彩抖動(dòng)、高斯模糊等操作中隨機(jī)選取一種對(duì)圖像進(jìn)行色彩上的增強(qiáng).在相關(guān)聯(lián)算子的計(jì)算過程中,參數(shù)k=7,對(duì)于所選取的網(wǎng)絡(luò)模型,其結(jié)構(gòu)均為4 個(gè)階段,也就是式 (5) 中的N=4,同時(shí)將每個(gè)階段的權(quán)重設(shè)為相等,也就是λi=1 .設(shè)置式 (1) 中的τ=4 .最后,設(shè)置式 (6)中的α=0.2,β=5 .
本文所提出的方法在多種模型結(jié)構(gòu)上進(jìn)行實(shí)驗(yàn)驗(yàn)證,選取ResNet[4]與WideResNet[38](WRN)作為網(wǎng)絡(luò)主干,并在多種教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)組合上進(jìn)行實(shí)驗(yàn).表1 總結(jié)了4 組教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)的參數(shù)量與計(jì)算量信息.在表2 和表3 中,總結(jié)了本文所提出方法的性能效果,其中本文所提出的基于隱層相關(guān)聯(lián)算子的知識(shí)蒸餾方法記為CorrKD,僅利用中間層式 (5) 與交叉熵?fù)p失訓(xùn)練得到的學(xué)生網(wǎng)絡(luò)方法簡(jiǎn)稱為Corr,KD 表示僅利用式 (3) 進(jìn)行訓(xùn)練的傳統(tǒng)知識(shí)蒸餾訓(xùn)練結(jié)果.注意到表2 與表3 中的第3 和第4 列分別表示教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)在正常情況下訓(xùn)練得到的基準(zhǔn)準(zhǔn)確率結(jié)果 (即只使用交叉熵?fù)p失函數(shù)).KD 展示了學(xué)生網(wǎng)絡(luò)在利用式 (3) 訓(xùn)練得到的傳統(tǒng)知識(shí)蒸餾方法的結(jié)果.
表1 實(shí)驗(yàn)所用模型參數(shù)量與計(jì)算量信息Tab.1 Model parameters and FLOPs information used in the experiment
從實(shí)驗(yàn)結(jié)果來看,單純基于中間隱層相關(guān)聯(lián)算子的知識(shí)遷移方法可以對(duì)于學(xué)生網(wǎng)絡(luò)的訓(xùn)練帶來一定的促進(jìn)作用,但效果并不明顯.通過結(jié)合了輸出層的傳統(tǒng)知識(shí)蒸餾方法KD 之后,在學(xué)生網(wǎng)絡(luò)的分類正確率上,獲得了很好的性能提升.在蒸餾教師網(wǎng)絡(luò)WRN40-2 時(shí),在CIFAR-10 上學(xué)生網(wǎng)絡(luò)WRN16-2 的網(wǎng)絡(luò)參數(shù)和網(wǎng)絡(luò)計(jì)算量都約為原來教師網(wǎng)絡(luò)WRN40-2 的31.8%,即參數(shù)量 (教師網(wǎng)絡(luò)參數(shù)量為2.2 M,學(xué)生網(wǎng)絡(luò)參數(shù)量為0.7 M,教師網(wǎng)絡(luò)計(jì)算量為329.0 M,學(xué)生網(wǎng)絡(luò)計(jì)算量為 101.6 M).如表2 所示,由本文所提出的CorrKD 方法得到的學(xué)生網(wǎng)絡(luò)準(zhǔn)確率只下降了0.5 百分點(diǎn) (教師網(wǎng)絡(luò)準(zhǔn)確率為95.2%,學(xué)生網(wǎng)絡(luò)使用CorrKD 方法準(zhǔn)確率為94.7%).對(duì)于類別個(gè)數(shù)更多的CIFAR-100 上,同樣蒸餾的網(wǎng)絡(luò)選擇,由本文所提出的CorrKD 方法壓縮WRN40-2 后的網(wǎng)絡(luò)計(jì)算量和參數(shù)量約是壓縮前的31.8% (表1),準(zhǔn)確率只下降1 百分點(diǎn) (表3 中教師網(wǎng)絡(luò)準(zhǔn)確率為76.8%,由CorrKD 方法得到的準(zhǔn)確率為75.8%).由此可見,本文所提出的方法在準(zhǔn)確率有限下降的情況下,模型能夠獲得顯著的壓縮比,壓縮后形成的學(xué)生網(wǎng)絡(luò)能夠有效嵌入受限移動(dòng)設(shè)備端中.
表2 CorrKD 在CIFAR-10 上實(shí)驗(yàn)結(jié)果Tab.2 Experimental results of CorrKD on CIFAR-10
表3 CorrKD 在CIFAR-100 上實(shí)驗(yàn)結(jié)果Tab.3 Experimental results of CorrKD on CIFAR-100
在CIFAR-100 上,也可視化了本文所提出的CorrKD 方法對(duì)于蒸餾WRN16-2 的訓(xùn)練損失的變化以及測(cè)試準(zhǔn)確率的變化.如圖3 所示,隨著訓(xùn)練的回合數(shù)的增加,完整訓(xùn)練損失Lo逐步減小,同時(shí)測(cè)試準(zhǔn)確率逐漸提升.該訓(xùn)練結(jié)果驗(yàn)證了本文所提出的方法在訓(xùn)練上的穩(wěn)定性與有效性.
圖3 完整訓(xùn)練損失Lo 和測(cè)試準(zhǔn)確率變化曲線Fig.3 Curves of overall training loss Loa nd test accuracy with respect to the epoch number
在CIFAR-100 評(píng)測(cè)數(shù)據(jù)集上并以WideResNet 為主干網(wǎng)絡(luò),將本文所提出的方法與其他經(jīng)典基于中間層的知識(shí)蒸餾方法進(jìn)行對(duì)比,包括FitNet[20],AT (Attention Transfer)[21],SP (Similarity-Preserving)[28]和FT (Factor Transfer)[29].為保證公平性,上述中間層蒸餾方法都展示與傳統(tǒng)KD 相結(jié)合訓(xùn)練的實(shí)驗(yàn)結(jié)果,各方法所得到的結(jié)果對(duì)比如表4 所示.從實(shí)驗(yàn)結(jié)果來看,本文所提出的知識(shí)蒸餾方法在WideResNet 模型結(jié)構(gòu)上,和其他中間層的知識(shí)蒸餾方法相比,取得了較好水平.例如,在學(xué)生網(wǎng)絡(luò)為WRN16-1 時(shí),本文所提出的方法和AT 方法相比,準(zhǔn)確率提高了0.1 百分點(diǎn) (CorrKD 準(zhǔn)確率為74.6%,AT 準(zhǔn)確率為74.5%),同時(shí),與教師網(wǎng)絡(luò)WRN40-2 相比,準(zhǔn)確率降低2.2 百分點(diǎn) (CorrKD 準(zhǔn)確率為74.6%,WRN40-2 準(zhǔn)確率為76.8% (表3)).
表4 CorrKD 與其他知識(shí)蒸餾方法在CIFAR-100 上準(zhǔn)確率對(duì)比Tab.4 Accuracy comparison between different KD methods and CorrKD on CIFAR-100
本節(jié)主要探索部分超參數(shù)對(duì)于實(shí)驗(yàn)效果的影響,主要包括相關(guān)聯(lián)算子中參數(shù)k的影響以及完整的訓(xùn)練損失函數(shù)中參數(shù)α,β的影響.實(shí)驗(yàn)均在CIFAR-100 上進(jìn)行,教師網(wǎng)絡(luò)結(jié)構(gòu)選取WRN40-2,學(xué)生網(wǎng)絡(luò)結(jié)構(gòu)選取WRN16-2.對(duì)于3 組參數(shù)的實(shí)驗(yàn)結(jié)果分別如表5 和表6 所示,“教師網(wǎng)絡(luò)→學(xué)生網(wǎng)絡(luò)”表示教師網(wǎng)絡(luò)蒸餾學(xué)生網(wǎng)絡(luò)所使用的網(wǎng)絡(luò)模型.在k相關(guān)的實(shí)驗(yàn)中,固定α=0.2,β=5 ;同理,在α與β相關(guān)的實(shí)驗(yàn)中,固定其他兩個(gè)參數(shù).從實(shí)驗(yàn)結(jié)果看出,實(shí)驗(yàn)中所選取的參數(shù)k=7,α=0.2,β=5 均為最佳參數(shù).
表5 相關(guān)聯(lián)算子參數(shù) k 實(shí)驗(yàn)結(jié)果對(duì)比Tab.5 Comparison with different values of k in the correlation operation
表6 完整訓(xùn)練損失 Lo 中參數(shù) α ,β 實(shí)驗(yàn)結(jié)果對(duì)比Tab.6 Comparison with different values of α ,β in the overall training lossLo
本文提出了一種新的基于隱層相關(guān)聯(lián)算子的知識(shí)蒸餾方法,首次將用于光流中的相關(guān)聯(lián)算子計(jì)算操作運(yùn)用到模型中間隱含層的特征提取中,相關(guān)聯(lián)算子可以對(duì)特征之間的匹配程度或變化過程進(jìn)行有效建模,反映模型中間層的表征信息.同時(shí)在數(shù)據(jù)增強(qiáng)的作用下,進(jìn)行中間層的知識(shí)遷移,結(jié)合輸出層的傳統(tǒng)知識(shí)蒸餾方法,構(gòu)成了本文所提出的全新知識(shí)蒸餾框架.實(shí)驗(yàn)表明,本文所提出的知識(shí)蒸餾方法在兩種公開數(shù)據(jù)集上均取得了優(yōu)越性能,并在WideResNet 模型上取得了同類型中間層知識(shí)蒸餾方法中的最優(yōu)水平.在未來的研究中,可以考慮將該模型中間層表征知識(shí)提取方法利用到更多視覺領(lǐng)域下游任務(wù)的蒸餾中,并在多個(gè)任務(wù)上驗(yàn)證本文所提出方法的壓縮效果.