葛海波,周 婷,黃朝鋒,李 強
西安郵電大學 電子工程學院,西安 710000
近年來,卷積神經網絡(convolutional neural network,CNN[1]引起了業(yè)內學者廣泛的關注并得到了快速的發(fā)展,在圖像分類[2,4]、目標檢測[5,7]等任務中的表現取得了巨大的進步,但高性能的深度學習網絡常常是計算密集型和參數密集型的,這一特點限制了網絡在低資源設備上的應用。因此,各種模型壓縮技術應運而生[8-9]。
知識蒸餾(KD)[10]因其在壓縮率和精度保留方面的優(yōu)異表現引起研究人員們的關注,繼而誕生了FitNets算法使學生網絡盡量學習教師網絡的隱藏特征值[11],該算法在性能方面的表現并不突出,但提供了特征提取這個很有價值的蒸餾思想,為知識蒸餾的發(fā)展打開了新思路。文獻[12]引入注意力機制(AT),利用模型的注意力圖指導訓練,使簡單模型生成的注意力圖與復雜模型相似,以此來提升學生模型的學習效果。文獻[7]提出包含特征交互和特征融合兩個模塊的注意力特征交互蒸餾(AFID),該方法利用交互教學機制完成兩個網絡之間的發(fā)送、接收和傳遞反饋,促進學生網絡的學習效果。文獻[13]在多教師模型的基礎上,利用多組策略傳遞各個教師網絡的中間特征(AMTML-KD),使學生網絡自適應地學習不同等級的特征知識。文獻[14]使用對抗訓練的方式在線互相學習特征圖的分布,同時結合logit蒸餾進一步提高在分類任務上的準確率(AFD)。
上述蒸餾算法在圖像分類任務上的精度和壓縮效果都有很好的表現,但讓學生網絡直接學習教師網絡的中間層特征信息并不能很好地訓練學生網絡的擬合能力,導致其在目標檢測、文本檢索和語義分割等任務上的表現并不突出且算法并不支持在不同架構、不同維度的網絡之間轉移知識,對知識蒸餾技術的發(fā)展和應用帶來一定的局限性。
針對上述問題,本文提出一種適用于多任務的基于特征分布蒸餾算法,主要有以下幾點工作:
(1)引入互信息對輸入圖像的特征分布建模,利用條件概率分布代替互信息中的聯合概率密度分布,使得到的特征分布結果描述更準確。
(2)在知識蒸餾中將教師網絡的特征分布作為學生網絡的學習知識,并在損失函數的設計中引入最大平均差異(MMD),以最小化教師網絡和學生網絡特征分布之間的距離,使學生網絡更好地擬合教師網絡的特征分布情況。
(3)在知識蒸餾的基礎上使用toeplitz 矩陣實現學生網絡的全連接層權重參數共享,進一步實現模型的壓縮與加速。
(4)在圖像分類、目標檢測和語義分割等不同的任務上進行了實驗,其在CIFAR-100、ImageNet 和VOC 數據集上都取得了很好的蒸餾效果。
以FitNets 蒸餾為代表的特征蒸餾算法中,通常是將每個位置空間的激活作為一個特征,每個濾波器中的二維激活映射作為神經元表示某個空間樣本,僅僅反映了CNN是如何解釋該樣本的,忽略了CNN在處理該樣本時更關注哪部分以及更強調使用哪種類型的激活模式。因此,直接匹配教師網絡的特征圖并不能很好地鍛煉學生網絡的特征擬合能力,因為它忽略了空間中樣本的特征分布情況。
從上述問題出發(fā),本文提出基于特征分布蒸餾算法,通過條件互信息分別對教師網絡和學生網絡的特征分布進行建模,利用MMD 度量兩者之間的距離,使學生網絡更好地擬合教師網絡的特征分布情況,結合MMD匹配損失和目標蒸餾損失函數來設計本文算法的損失函數。算法原理圖如圖1所示,其中灰色框內表示算法損失函數,兩個正方形框分別表示教師與學生網絡在經過最小化MMD匹配前后的特征分布狀態(tài),圖中的三角形和圓圈分別對應教師網絡與學生網絡的特征分布,對比兩種情況下的特征分布情況,可以直觀地看出經過特征分布蒸餾后的學生網絡的特征分布與教師網絡的特征分布更為相近。因此,從理論上分析,特征分布蒸餾算法可以使學生網絡更好地擬合教師網絡的特征分布,從而提高學生網絡的特征提取能力。
圖1 特征分布知識蒸餾原理圖Fig.1 Principle diagram of characteristics of distributed knowledge distillation
知識蒸餾與遷移學習的思想較為相似,但知識蒸餾更加強調知識的遷移而非權重的遷移。目標知識蒸餾主要側重于學習大型網絡的輸出知識,如邏輯單元和類概率,類概率是由Softmax激活函數轉化的邏輯單元,其公式如式(1)所示:
其中,Zi表示第i類的邏輯單元值,Pi是第i類的類概率,k表示不同種類的數量,T表示溫度系數,當T=1時則表示輸出的類概率,當T=∞時,其為輸出的邏輯單元[10]。
目標蒸餾時的損失函數如式(2)所示:
其中,N代表小批量的數量,LCE表示交叉熵,δ表示Softmax函數,yi表示樣本i的真實標簽,ZS∈RC,ZT∈RC分別為C 類任務上兩個網絡的Softmax 輸出。該損失函數主要分為兩部分:當T=∞時,學生網絡和教師網絡分類預測的交叉熵;當T=1 時,真實標簽與學生網絡的分類預測交叉熵。
1.2.1 預備知識
知識蒸餾是將兩個模型之間傳遞的知識表示為集合T,T={T1,T2,…,TN},稱該集合為轉移集。設x=f(t),y=g(t,W),x、y分別用來表示教師模型和學生模型的輸出,其中W為學生模型的參數。在知識蒸餾過程中,g(·)中的參數W用來學習并模擬f(·)的表示。
互信息是利用兩個隨機數據點間的相互關系描述特征空間的幾何形狀,同時也表示兩個隨機變量間相互依賴的程度,決定著聯合分布P(x,y)和分解的邊緣分布乘積P(x)P(y) 之間相似程度。設隨機變量為X、Y,兩者之間的聯合概率密度可表示為ρ(x,y)=P(X=x,Y=y),則X與Y之間的互信息I(X,Y)的計算公式如式(3)所示:
其中,P(x,y)表示隨機變量X、Y的聯合概率密度,P(x)和P(y)分別表示X和Y的邊緣概率密度。
1.2.2 條件互信息
本文采用互信息對教師網絡和學生網絡的特征分布情況進行建模,為更好地擬合特征分布情況,互信息中利用核密度估計近似聯合密度概率,教師網絡和學生網絡中的聯合概率密度分別表示為式(4)和(5)所示:
其中,K(a,b;2δ2t)表示一個寬度為δt的對稱核,a、b分別表示向量。
雖然使用聯合概率密度分布來建模數據的幾何分布并進行知識蒸餾可以解決傳統(tǒng)知識蒸餾存在的許多問題,但考慮到樣本局部區(qū)域內的數據存在相似性這一特點,為更準確地描述樣本局部區(qū)域,本文提出條件互信息的概念。利用條件概率分布代替原有的聯合概率密度函數,通過條件互信息對特征分布情況進行建模,教師模型和學生模型中的條件概率分布可分別表示為式(6)、(7):
為驗證條件互信息的有效性,分別對基于條件互信息和傳統(tǒng)互信息的特征提取做了對比實驗,結果如圖2所示,其中圖2(a)為輸入原始圖像,圖2(b)為基于傳統(tǒng)互信息的特征提取,圖2(c)為基于條件互信息的特征提取。圖2(b)、(c)中高亮的部分代表中心特征,可以看到經過條件互信息處理后的模型對樣本中局部區(qū)域的關注度更高,更能捕捉到中心特征附近與之相似的特征信息。
圖2 條件互信息與傳統(tǒng)互信息的特征提取對比圖Fig.2 Comparison of feature extraction between conditional mutual information and traditional mutual information
由于經特征分布蒸餾算法訓練后的學生網絡在提取特征時會更聚焦于樣本的局部區(qū)域,更容易捕捉到樣本的特征區(qū)間,這一特點在圖像處理任務上都有很大的優(yōu)勢。因此,在第3章的實驗中本文將該算法在圖像分類、目標檢測和語義分割任務上進行了實驗,驗證該算法在處理不同任務時的性能表現。同時,由于特征分布蒸餾算法是通過對教師網絡的特征分布進行建模并以此為蒸餾的知識指導學生網絡的學習,而不是讓學生網絡學習教師網絡的中間層特征圖。因此,特征分布算法并不受限于只能在相同網絡架構間進行蒸餾,對此本文在第3章中設置了實驗進行驗證。
其中,k(·,·)是核函數,將樣本向量映射到高階特征空間中。綜上所述,最小化MMD相當于最小化α和β之間的距離,利用這一思想設計MMD匹配損失函數。
由于本文所提的知識蒸餾算法是利用教師網絡的特征分布和Softmax 層輸出概率來指導學生網絡的學習。因此,損失函數的設計可分為以下兩部分:MMD匹配損失函數以及目標蒸餾的交叉熵損失,則損失函數的定義為:
其中,H(·)為標準的交叉熵損失,ytrue表示真標簽,PS為學生網絡的輸出概率,ω為正則項損失函數的權重。
由于現有方法大多數側重于蒸餾技術的提高,通常會忽略學生網絡作為一個完整的神經網絡本身具有的可壓縮空間,因此本文在知識蒸餾的基礎上結合參數共享方法。
參數共享主要是針對學生網絡的全連接層,將全連接層的權重參數表示為m×n維的矩陣W,其中n、m分別表示輸入層和輸出層中神經元的數量,圖3(a)表示輸入層和輸出層之間的神經元連接,圖3(b)表示輸入層與輸出層間的權重參數矩陣。輸入層第i個神經元與輸出層神經元之間的連接可表示為一個n維向量,用ωi表示ωi=(ωi1,ωi2,…,ωin),其中ωij表示連接輸入層第i個神經元和輸出層第j個神經元的權重參數。本文利用toeplitz矩陣實現網絡全連接層的參數共享。toeplitz矩陣簡稱為T型矩陣,矩陣形式如公式(11)所示:
圖3 輸入層與輸出層神經元連接和權重參數矩陣Fig.3 Input layer and output layer neurons connected and weighting parameter matrix
在toeplitz 矩陣中只有2n-1 個獨立的元素。經過toeplitz 矩陣規(guī)劃后的網絡在輸出層與輸入層神經元之間的連接和權重參數矩陣如圖4(a)、(b)所示。
圖4 toeplitz矩陣規(guī)劃后輸入層與輸出層連接和權重矩陣Fig.4 Connection between input layer and output layer and weight matrix after toeplitz matrix planning
利用toeplitz 矩陣完成全連接層的參數共享,網絡反向傳播過程中的梯度求取公式將會發(fā)生變化。全連接層權重參數可表示為行向量ωrow=(ω11,ω12,…,ω1n)和列向量ωcol=(ω11,ω21,…,ωm1),可見行向量與列向量都有權重參數ω11。x為輸入神經元,x=(x1,x2,…,xn)T,y為輸出神經元y=(y1,y2,…,ym)T,E為模型誤差,?E?y為誤差經過反向傳播返回到當前層的值,?E?y和y都是m維列向量,在對權重參數梯度進行求取時,將輸入的n維x擴展為n+m-1維,具體是:當j >n時,xj=0 即x=(x1,x2,…,xn,0,0,0,…,0)T。行向量ωrow的梯度公式如式(12)~(15)所示:
列向量ωcol的梯度公式如式(16)~(19),ω11的梯度由公式(12)求得公式(16):
其中,LSR(w,g)表示將向量ω邏輯右移g位,×代表兩個矩陣相乘,T表示矩陣的轉置。
為證明本文所提特征分布蒸餾算法的有效性,首先在圖像分類任務上對該算法進行了實驗驗證;其次,在目標檢測和語義分割任務上進行了實驗,驗證該算法在不同圖像處理任務上的優(yōu)勢;最后在每個圖像處理任務中分別設置兩組教師網絡和學生網絡的蒸餾實驗,驗證該算法在不同網絡架構間的蒸餾效果。
由于CIFAR-100 是在圖像分類任務上運用最廣泛的數據集,為保證與其他知識蒸餾算法的比較的公平性,在圖像分類任務中選擇CIFAR-100 數據集進行實驗。實驗是由PyTorch實現,最小批次設置為64,epochs設置為500,初始學習率設為0.1,然后分別在第200 輪次和300 輪次時將學習率調整為0.01 和0.001。蒸餾過程中的參數設置,logits蒸餾中軟化系數T為4.0。
CIFAR-100數據集總共有60 000張尺寸為32×32的彩色圖像,包含有100 個類,每一類有600 張圖像,其中500張圖像用于訓練,100張用于測試,分別采用ResNet-152 和ResNet-50 作為教師模型和學生模型進行實驗。首先設計消融實驗分別驗證算法中特征分布蒸餾和參數共享的有效性,其結果如表1 所示,可以看出經過特征分布蒸餾算法訓練后的學生網絡的分類準確率(Accuracy)有了明顯的提升,甚至優(yōu)于教師網絡,在壓縮率(Params)方面,學生網絡的參數量遠遠少于教師網絡。在知識蒸餾的基礎上對模型進行參數共享處理,實驗表明經過參數共享處理后的網絡實現了在精度僅減少0.1%~0.2%(可接受范圍內)的情況下參數量降低至教師網絡的33.87%,節(jié)省了模型對存儲空間的需求。同時將具有代表性的知識蒸餾算法KD、FitNets、AT、AFID、AMIMI-KD、AFD與本文所提的特征分布蒸餾算法在相同的參數和環(huán)境設置下做了對比實驗,實驗結果如表2所示。
表1 消融實驗Table 1 Ablation experiment
表2 CIFAR-100數據集上不同蒸餾算法間的性能對比Table 2 Performance comparison of different distillation algorithms on CIFAR-100 dataset
由表2的實驗結果可以看出,特征分布蒸餾算法在分類任務上的準確率達到了78.92%,相較于全監(jiān)督教師網絡提高了0.78%;相較于全監(jiān)督學生網絡提高了8.89%;相較于KD、FitNets、AT、AFID、AMIMI-KD、AFD算法分別提高了1.58%、6.9%、0.52%、1.92%、2.09%和0.85%。可以看出特征分布蒸餾算法在圖像分類任務上具有很好的性能表現,同時該實驗也證明了利用圖像在教師網絡中的特征分布情況來指導學生網絡學習的效果明顯優(yōu)于利用特征圖來指導學生網絡學習的效果。
上述實驗可以看出特征分布蒸餾算法優(yōu)秀的分類準確率和高壓縮比,顯示了該算法的壓縮潛力。為此,本文采用更大更復雜的數據集ImageNet進一步驗證該算法的收斂速度和魯棒性。ImageNet 數據集包含有120 萬張訓練圖像和5 萬張測試圖像,共分為1 000 個類。與CIFAR-100數據集相比,ImageNet數據集具有更豐富的種類,且圖像規(guī)模更大(平均469×387)。首先以ResNet-152和ResNet-50作為教師模型和學生模型測試學生網絡在經過特征分布蒸餾算法前后的Kappa系數,Kappa 系數是在分類任務中評價分類準確率的一個重要評價指標,其值越高表示分類準確率越高,最高值為1.0。實驗結果如圖5所示,可以看到無論是否經過本文蒸餾算法進行蒸餾,Kappa 系數均呈現上升趨勢,但ResNet-50未經特征分布蒸餾時的訓練數據集Kappa系數始終低于經過特征分布蒸餾后的Kappa 系數。同時未經過特征分布蒸餾的預測訓練集收斂速度也比較慢,且波動幅度也比較大,經特征分布蒸餾后的網絡在預測訓練集上的波動幅度小,并且在迭代次數為400次時便完成了收斂,開始趨于穩(wěn)定。實驗證明,在經過特征分布蒸餾后的網絡加快了圖像分類的收斂速度,提升了圖像分類準確率。
圖5 ResNet-50在特征分布蒸餾前后的Kappa系數Fig.5 Kappa coefficient of ResNet-50 before and after distillation of characteristic distribution
此外,本文采用Top-1 error和Top-5 error作為評價指標驗證該算法在不同網絡結構間的蒸餾能力,第一組仍然選用ResNet-152和ResNet-50作為教師模型和學生模型,第二組選用ResNet-50 作為教師網絡,MobileNet作為學生網絡。為保證實驗的公平性,兩個教師網絡均提前在PyTorch 庫中預先訓練,實驗結果如表3 所示。可以看出,本文的算法在ImageNet 數據集上仍有優(yōu)秀的表現。
表3 ImageNet數據集上的圖像分類性能表現Table 3 Image classification performance on ImageNet dataset
在相同網絡架構間進行蒸餾時,特征分布蒸餾算法的Top-1 error 和Top-5 error 相較于全監(jiān)督教師網絡分別減少了0.6%和4.7%;相較于全監(jiān)督學生網絡分別減少了9.56%和20.45%??梢钥闯鼋浱卣鞣植颊麴s后的網絡性能均優(yōu)于全監(jiān)督教師網絡和全監(jiān)督學生網絡。
在不同的網絡架構間進行蒸餾時,特征分布蒸餾算法的Top-1 error 和Top-5 error 相較于全監(jiān)督教師網絡是增加的,但相較于全監(jiān)督學生網絡而言分別減少了14.45%和20.6%??梢钥闯黾词故懿煌W絡架構的影響,特征分布蒸餾算法在不同網絡架構間的蒸餾效果仍是優(yōu)于全監(jiān)督學生網絡的。
本文將特征分布蒸餾算法應用于目前流行的高速檢測器SSD 上[15],所有模型都是用VOC2007 訓練集進行訓練,其中主干網絡使用ImageNet 數據集進行與訓練,將沒有經過蒸餾訓練的SSD 作為基準,分別以ResNet50 和ResNet18 作為教師網絡和學生網絡,為驗證本文所提算在目標檢測任務上的性能優(yōu)勢,將特征分布算法與AT、AFID、FitNets 和AFD 算法在目標檢測任務上的性能進行了比較,圖6 表示幾種算法分別在IOU=0.5 和IOU=0.7 時的PR 比曲線,從比較結果來看,經過特征分布蒸餾后的網絡在目標檢測任務上的準確率和召回率兩個性能指標均優(yōu)于其他蒸餾算法,這是因為特征分布蒸餾算法中利用條件互信息對樣本的特征分布情況進行建模,使其在蒸餾過程中更加聚焦于樣本的局部特征信息,對邊界框的回歸自然更加準確。
圖6 網絡在經過不同算法蒸餾后的PR曲線Fig.6 PR curve after distillation by different algorithms
同時分別以ResNet18 和MobileNet 為學生網絡,平均精度均值(mAP)作為評價指標進行實驗,比較了幾種不同的算法在目標檢測任務中對不同網絡架構的蒸餾效果。結果如表4所示。
表4 VOC2007測試集上的目標檢測性能比較Table 4 Target detection performance comparison on VOC2007 testset
實驗驗證了本文所提的特征分布蒸餾可以很好地應用于目標檢測任務上,同時該算法在壓縮率方面也有很好的表現,這是在其他對比算法中沒有實現的。
在相同網絡架構間蒸餾時,雖然特征分布蒸餾算法相較于全監(jiān)督教師網絡的平均精度值降低了3.3%,但模型壓縮率達到了教師網絡的44.1%;相較于全監(jiān)督學生網絡而言平均精度值提高了4.2%,同時特在此基礎上壓縮了模型的大小;相較于FitNets、AT、AFID、AFD幾種對比算法的平均精度值分別提高了2.4%、2.1%、1.04%和1.85%。
在不同網絡架構間蒸餾時,特征分布蒸餾算法的平均精度值相較于全監(jiān)督學生網絡提高了3.85%,且優(yōu)于其他幾種對比算法。同時,可以發(fā)現其他幾種對比算法的平均精度值與全監(jiān)督學生網絡相比均是降低的,證明了特征分布蒸餾在不同網絡架構間蒸餾的有效性。
本文對語義分割任務進行了知識蒸餾,實驗中以基于ResNet101 的DeepLabV3+作為教師網絡,并以基于ResNet18和MobileNetV2的DeepLabV3+分別作為學生網絡,以均交并比(mIoU)和像素準確率(PA)作為本文算法在語義分割任務上的性能指標,其實驗結果如表5所示,可以看到本文所提的特征分布蒸餾算法大幅度提高了ResNet18 和MobileNet 模型的性能表現,同時該算法的壓縮效果優(yōu)于另外幾種對比算法。
表5 VOC2012測試集上的語義分割性能比較Table 5 Semantic segmentation performance comparison on VOC2012 testset
在相同網絡架構間進行蒸餾時,特征分布蒸餾算法的均交并比和像素準確率相較于全監(jiān)督教師網絡的分別提高了4.34%和0.5%,相較于全監(jiān)督學生網絡分別提高了14.01%和4.28%;而其他特征蒸餾算法的均交并比和像素準確率相較于教師網絡均是減小的,相較于全監(jiān)督學生網絡,兩種指標的增長幅度都很小。這是因為其他特征蒸餾算法中,采用教師網絡的特征圖指導學生網絡的學習,這種方法只傳遞給學生網絡一個特征提取的結果,并不能很好地鍛煉學生網絡的特征擬合能力。而本文的特征分布蒸餾算法指導學生網絡擬合教師網絡的特征分布情況,很好地鍛煉了學生網絡的特征擬合能力從而提高了學生網絡的特征提取能力,使學生網絡能更好地應用于不同的圖像處理任務中。
在不同的網絡架構間進行蒸餾時,特征分布蒸餾算法的均交并比和像素準確率相較于全監(jiān)督教師網絡是減少的,但是相較于全監(jiān)督學生網絡及其他幾種對比蒸餾算法而言是提高的,說明特征分布蒸餾算法雖然突破了不同網絡架構間蒸餾的限制,但還是受其影響的。
本文提出的特征分布知識蒸餾通過MMD 匹配教師網絡與學生網絡之間的特征分布情況,不需要對任何超參數進行特定的調整,同時通過利用條件互信息對數據樣本的建模,使學生網絡擬合教師網絡的特征分布,更好地訓練了學生網絡的特征擬合能力,使其可以很好地應用于多種圖像處理任務中,并且可以實現在不同維度空間內傳遞教師網絡的知識。其次在特征蒸餾的基礎上結合參數共享算法實現對學生網絡的壓縮,進一步節(jié)省了模型的資源需求。本文所提的特征分布蒸餾算法在圖像分類、目標檢測和語義分割任務場景下的性能和壓縮率均優(yōu)于其他幾種對比算法。后續(xù)研究中會進一步研究不同網絡架構間的蒸餾,提高不同網絡架構間的蒸餾效果。