曹炅宣 常 明 張 蕊③** 支 天** 張曦珊**
(*中國科學技術大學 合肥 230026)
(**中國科學院計算技術研究所 北京 100190)
(***中科寒武紀科技股份有限公司 北京 100191)
近年來,得益于越來越深的網絡層數和越來越大的參數量,深度神經網絡(deep neural networks,DNN)在各類任務中取得了顯著的成功。然而,DNN有著較高的計算復雜度和較大的參數儲存要求,因此將其部署到運算資源有限的設備或者對即時性要求較高的應用場景變得比較困難,例如智能手機、嵌入式設備和邊緣計算。因此壓縮大的模型并提高其運行速度變得非常重要。知識蒸餾(distilling the knowledge,KD)就是一種十分有效的模型壓縮方法,其通過從大型教師網絡中提取有用的知識轉移給小型學生網絡從而提高小型網絡的性能。知識蒸餾的損失函數包含2 個部分,一個是來自于真實標簽的任務損失,另一個部分則是來自于教師網絡的蒸餾損失。
因此,如何有效地找到2 個損失函數的權重成為了一個待解決的問題。換言之,如何在訓練過程中更合理地混合這2 個損失的梯度?,F在大多數已有的知識蒸餾方法都是手動調整損失權重,這種方法既繁瑣又十分浪費計算資源,并且往往無法達到最佳性能。手動搜索權重的問題主要在于權重的搜索空間范圍特別大,而且往往是連續(xù)的。例如,根據框架RepDistiller[1],在相同的數據集和師生網絡組合下,蒸餾損失權重在0.02(基于概率的知識轉移方法(probabilistic knowledge transfer,PKT)[2]) 到30 000(計算相關性的知識蒸餾方法(correlation congruence for knowledge,CC)[3])之間變化。
針對這個問題,可以采用超參數優(yōu)化(hyperparameter optimization,HPO[4])和多任務學習(multitask learning,MTL)來確定2 個損失的權重,但將這2 種方法應用到知識蒸餾的訓練時存在著一些缺陷。在知識蒸餾訓練中,存在2 個優(yōu)化目標:用于任務損失的真實標簽以及用于蒸餾損失的教師網絡。而知識蒸餾設計的初衷就是使用蒸餾損失作為輔助,幫助作為主要目標的任務損失降低到最小。但超參數優(yōu)化和多任務學習方法會認為這2 個損失處于一個平等情況,因此會產生大量冗余的搜索空間導致參數調節(jié)過程的效率十分低下,并且過分平衡用于輔助的蒸餾損失也有一定可能損害到主優(yōu)化目標(即任務損失),后續(xù)多任務學習的實驗也證明了這一點。
為了解決上述問題,本文提出了一種新穎的自動梯度混合方法,該方法可以自動地為知識蒸餾訓練找到合適的損失函數權重。本文將尋找合適的損失函數權重的問題轉換為尋找2 個損失通過反向傳播得到的最佳混合梯度的問題??紤]到在知識蒸餾中,蒸餾損失是任務損失的輔助這一重要的先驗知識,自動梯度混合方法可以顯著減少混合梯度的搜索空間。通過找到混合梯度的模長和方向從而確定用于更新模型參數的混合梯度。在具體訓練過程中,混合梯度的模長用來控制模型參數更新速度,而方向則是決定著模型最終的訓練結果。因此自動梯度混合方法通過固定混合梯度模長與任務損失產生的梯度模長相同,用來保證模型迭代的穩(wěn)定性。在只需要搜索方向的情況下,可以有效地減少混合梯度的搜索空間并提高搜索效率。在確定了混合梯度的模長和方向后,就可以計算出2 個損失函數的權重,從而避免了復雜的手動調節(jié)過程。
與現有的手動調節(jié)方法相比,本文提出的自動梯度混合方法有效利用了知識蒸餾的先驗知識,具有以下幾個優(yōu)點:首先,自動梯度混合方法將混合梯度的模長約束到與任務梯度模長相同,這樣能夠保證模型訓練的收斂穩(wěn)定性,解耦了梯度向量模長和方向,只需要在方向上進行搜索,顯著減少了搜索空間;此外,在進行了該梯度模長的約束后,早期訓練輪次的結果與最終訓練輪次的結果具備一個較好的保序性,從而通過一個極短時間的預訓練即可找到較優(yōu)的混合方向,從而實現了比手動設置權重更好的性能;最后,自動梯度混合方法是一種簡單易用的方法,能夠適用于絕大部分的知識蒸餾方法,可以對某種蒸餾方法在某類應用場景下是否有效進行一個快速驗證。
為了證明自動梯度混合方法的效果,本文在CIFAR-100[5]和ImageNet-1k[6]數據集上使用Rep-Disitiller[1]框架進行實驗,自動梯度混合方法在130個組別中表現超過70%的手動調節(jié)結果。在時間上,與超參數優(yōu)化方法相比,自動梯度混合方法只需要1/10 或者更少的時間就能達到與超參數優(yōu)化方法相當的精度。
知識蒸餾將大的、笨重的教師網絡的知識轉移給更小、更敏捷的學生網絡中,從而能夠有效提高學生網絡的性能。Hinton 等人[7]提出了這種方法,該方法使用溫度來修正教師網絡輸出的softmax,使其作為軟標簽來指導小型的學生網絡。目前有3 種不同類型的知識蒸餾,分別是基于響應、基于特征和基于關系的知識蒸餾方法[8]?;陧憫姆椒╗7]旨在通過使用教師網絡的logits 作為知識來直接模擬教師網絡的最終預測。基于特征的方法[9-12]則是專注于匹配教師網絡和學生網絡中間層的特征。基于關系的方法[1,3,13-14]認為不同層或數據樣本間的關系能有助于蒸餾。然而,現有絕大部分方法都使用手動調整來找到合適的任務損失權重和蒸餾損失權重,這既繁瑣又十分耗時,而且往往無法達到最佳性能。
超參數優(yōu)化(HPO)方法是一類尋找最優(yōu)的超參數組合的方法。這些方法可以分成3 類。第1 類是窮舉搜索,例如隨機搜索和網格搜索。網格搜索將超參數空間劃分為不同的網格并運行每個網格對應的參數組合以此找到最佳參數。這種遍歷式的搜索方法由于沒有對搜索空間進行任何裁剪,因此非常耗時。為了使得搜索過程效率更高,研究人員提出了第2 類啟發(fā)式搜索方法,該類搜索方法可以在搜索過程中根據可用信息(例如之前訓練的結果)選擇后續(xù)最佳的搜索分支。超參數優(yōu)化方法中包含有一些經典的啟發(fā)式搜索方法,例如樸素進化和模擬退火。最近,研究人員也提出了Hyperband[15]、Popluation-Based Training[16]等新的啟發(fā)式方法。第3 類是貝葉斯優(yōu)化,它通過條件概率建模來預測給定超參數的最終性能,例如序列貝葉斯優(yōu)化(sequential Bayesian optimization hyperband,BOHB)[17]、樹形Parzen 估計方法(tree-structured Parzen estimator approach,TPE)[18]等。與手動設置參數相比,超參數優(yōu)化方法的調節(jié)器理論上可以節(jié)省一些搜索時間,但是仍然非常耗時。
多任務學習(multi-task learning,MTL)是指通過使用所有任務和其他一些任務中包含的知識來共同學習多個任務,以此來提高每個任務性能的一種訓練方法。多任務學習方法包括2 個方面[19]。一些多任務學習方法設計深度學習多任務架構,包含有設計側重于編碼器[20-21]或側重于解碼器[22-23]的架構。其他的一些多任務學習方法則是側重于平衡多個任務的訓練優(yōu)化,例如Uncertainly[24]、GradNorm[25]、DWA[26]、DTP[27]、Multi-Objective Optim[28]。絕大部分多任務學習方法都會等權重優(yōu)化所有任務或者是所有損失函數,因此它可能會和知識蒸餾中將任務損失視為主要損失、將蒸餾損失視為輔助的理念相沖突。
為了提高小型學生網絡的性能,知識蒸餾除了利用來自于真實數據的監(jiān)督外,還額外引入了來自于大型的教師網絡中的有益的知識。因此,總的損失函數由來自于真實標簽的任務損失和來自于教師網絡的蒸餾損失構成,公式為
這里Lkd是總的知識蒸餾損失函數,Ltask是任務損失,Ldistill是蒸餾損失。α和β是任務損失和蒸餾損失的縮放系數。為了獲得合適的系數α和β,絕大部分已有的知識蒸餾方法都是通過手動調節(jié)方法來進行搜索,這類方法非常繁瑣又耗時,并且往往無法使學生網絡擁有最佳的性能。為了解決這個問題,本文提出了一種自動梯度混合方法來自動高效地找到損失權重。
假設在整個訓練過程中,第t輪的模型參數更新迭代時,損失函數對模型參數求導后得到的梯度被用來迭代模型參數,公式為
為了有效搜索最優(yōu)混合梯度,需要盡可能地縮小搜索空間。在這項工作中,本文利用了知識蒸餾中的一個重要的先驗知識,即任務損失是主要優(yōu)化目標,而蒸餾損失是任務損失的輔助。因此,混合梯度Gkd應當與任務梯度Gtask更加相關,蒸餾梯度Gdistill用來做一個細化調整。本文通過確定混合梯度Gkd的方向和模長來找到這個混合梯度。一般而言,在使用梯度來更新模型參數的過程中,梯度向量具有2 個自變量,一個是方向,另一個則是模長,兩者的功能具有一定差異。梯度的模長主要影響著模型參數的更新速度,從而控制模型收斂,當模長太長時,會出現梯度爆炸使得模型無法收斂或者是在最優(yōu)值附近徘徊的情況;而模長過短時,模型收斂會非常緩慢,找到最優(yōu)值的時間過長,也有可能陷入到某個局部最優(yōu)點中。梯度的方向則是決定著模型參數的更新方向,決定模型最終的收斂位置能否在相應的指標上取得好的效果(如分類任務中的準確率,檢測任務中的mAP 等)。在非蒸餾訓練中,模型僅使用任務損失產生的梯度就能訓練出來一個穩(wěn)定的結果。本文基于上述先驗知識,為了提高效率減小搜索空間,以及保證模型訓練的收斂穩(wěn)定性,自動梯度混合方法將混合梯度的模長約束到與任務損失梯度模長相同,公式為
在實現該約束后,可以很方便地將學生網絡的非蒸餾訓練版本的超參數,如學習率、權重衰減等,方便應用到本文中使用的蒸餾訓練上。因此可以通過對Gkd的模長約束得到一個穩(wěn)定的訓練過程。
在確定了模長大小后,自動梯度混合方法只需要在搜索空間中搜索Gkd梯度方向,該梯度方向由任務梯度Gtask和蒸餾梯度Gdistill決定。如圖1 所示,Gkd方向的搜索空間為Gtask和Gdistill之間的角度空間。θ為Gtask和Gdistill夾角大小:
圖1 梯度混合示意圖
假設Gtask和Gkd的夾角為λθ,只需要在λ∈[0,1] 這個范圍內進行搜索。在這種方式下,由于不需要對Gkd的模長進行搜索,整個搜索空間得到大幅度縮減,同時對最優(yōu)方向的搜索可以保證混合梯度Gkd的有效性。
通過搜索得到λ后,可以用如下公式表示Gkd的方向:
使用式(3)和(4),可以得到:
聯(lián)立式(5)~(7),可以解得損失權重系數α和β為
如式(9)所示,損失權重系數α和β取決于λ。λ的有效值為[0,1]。當λ等于0 時,蒸餾損失對混合梯度沒有任何影響;當λ等于1 的時候,混合梯度方向會完全遵循蒸餾梯度的方向。此外,實驗結果表明自動梯度混合方法在訓練早期和后期的性能(在分類任務中為準確率)有著良好的保序性。因此,為了進一步提高搜索過程中實驗的效率,本文使用訓練早期的訓練效果來預測最終的性能。在具體操作中,本文在搜索空間中對λ進行一個早期的搜索來作為預熱訓練模型。然后選擇性能最佳的一個作為λ的最佳值。之后可以采用式(9)來計算損失權重α和β,并且使用它們來完成訓練。搜索和訓練模型的整個過程如算法1 所示。
本節(jié)中,本文將提出的自動梯度混合方法應用在被廣泛使用的圖像分類數據集CIFAR-100[5]和ImagNet LSVRC 2012[6]上。此外,本文使用的Rep-Distiller[1]框架基于Pytorch,其模型庫中包含有13種流行的蒸餾方法。在實驗中,本文遵循RepDistiller 默認的超參數設置,如訓練輪次、學習率、優(yōu)化器等。在自動梯度混合方法中,預熱輪次設置為5。作為對比實驗,本文使用RepDistiller 中給出的手動調整的損失權重的訓練結果作為基線。
本文在KD[7]、Fitnets[11]、SP[29]、AT[12]、CC[3]、VID[29]、RKD[13]、PKT[3]、FT[10]和NST[9]這10 種蒸餾方法上進行實驗。此外,實驗還包含有7 個相似架構的師生網絡組合和6 個不同架構的師生網絡架構,即整個實驗包含有10 ×13 個小的實驗。
結果如表1 所示,可以發(fā)現自動梯度混合方法和手動方法比較,無論是在教師網絡架構和學生網絡架構相似的VGG13-VGG8 和ResNet110-ResNet32亦或者是ResNet32x4-ShuffleNetV2 和VGG13-MobileNetV2 這類架構差異很大的網絡上都有比較好的效果??偨Y表1 的結果可以發(fā)現,自動梯度混合方法在70%的蒸餾組合上都要比手動調節(jié)的方法表現得更好。
表1 在數據集CIFAR-100 上使用手動調節(jié)(Manual)和 自動梯度混合方法(AGB)在10 種不同的蒸餾方法和13 種不同的師生網絡組合的Top-1 準確率(%)
本文使用KD、CC、對比表示知識蒸餾方法(contrastive representation distillation,CRD)和注意知識蒸餾方法(attention on distillation,AT)在Image Net-1K數據集進行實驗。因為RepDistiller 框架沒有ImageNet-1K 對應代碼,所以本文在ImageNet-1K 上復現了這4 種方法。超參數和手動調整的損失權重是按照另一個蒸餾框架TorchDistil 設置的。本文使用Pytorch 團隊發(fā)布的模型ResNet34 和ResNet18 作為教師和學生網絡,并遵循TorchDistill 的ImageNet 訓練設置。
表2 展示了自動梯度混合方法和手動參數設置方法在以ResNet34 和ResNet18 作為師生網絡組合上的top-1 準確度。對于KD、CC 和AT 方法,自適應梯度混合方法可以獲得更好的性能,對于CRD 方法,自動梯度混合方法也可以達到和手動設置接近的性能。因此,ImageNet-1K 上的實驗有效證明了自動梯度混合方法的有效性。
表2 自動梯度混合方法(AGB)和手動調整(Manual)在ImageNet-1k 上的Top-1 準確度(%),其中教師網絡是ResNet34(top-1 準確度73.314%),學生網絡是ResNet18(top-1 準確度69.76%)
本文在CIFAR100 上使用自動梯度混合方法和Microsoft Neural Network Intelligence (NNI)的3 個不同的超參數優(yōu)化調節(jié)器進行了對比。這些超參數優(yōu)化方法包括有啟發(fā)式搜索方法模擬退火(simulated annealing)、Hyperband[15]和貝葉斯優(yōu)化方法TPE[18]。選擇VGG13 和VGG8 作為師生網絡,并使用AT 蒸餾方法進行實驗,在超參數優(yōu)化方法中,參照式(1),設置α等于1,β的搜索空間為0.02 到30 000。
圖2 顯示了3 個超參數優(yōu)化調節(jié)器和自動梯度混合方法的比較實驗??梢杂^察到自動梯度混合方法只需要極少訓練的時間就能達到非常高的精度。相比之下,在運行同樣的時間中,超參數優(yōu)化方法只能實現更低的精度。盡管超參數優(yōu)化方法在最終的結果中達到了與自動梯度混合方法相當或者略高的精度,但它們需要更多的時間來進行搜索,這是非常低效的。
分析超參數優(yōu)化方法出現的問題,可以發(fā)現無論是手動調節(jié)、超參數優(yōu)化或者是一些簡單約束情況,都會導致超參數搜索過程變得漫長而復雜。本質上,這是由于這類方法在搜索超參數時會將總梯度向量模長和方向進行耦合,同時去搜索梯度向量的方向和模長,會影響模型的收斂性,并出現兩類冗余搜索的情況:(1)搜索到合適的方向而模長過長或過短,導致出現模型無法收斂;(2)搜索到合適的模長而方向不對,這樣會影響模型最終的收斂位置,即影響模型最終的結果。而當一些更為奇怪的約束使得總梯度向量的方向與模長耦合得更加緊密時,甚至無法搜索到對應合適方向。
將Uncertainly 和GradNorm 這2 種無超參數的多任務學習方法與自動梯度混合方法進行對比實驗。本文對所有的10 種蒸餾方法進行了實驗,所有的13 種教師學生網絡組合與3.1 節(jié)中的相同。
如表3 所示,自動梯度混合方法應用到絕大多數蒸餾方法中都優(yōu)于這2 種多任務學習方法。多任務學習方法將蒸餾損失和任務損失平等對待,忽略了知識蒸餾的重要先驗知識,即任務損失是起到主導作用的,而蒸餾損失是用于輔助的。因此,多任務學習方法可能會為了最大限度地降低蒸餾損失而犧牲了性能。還可以發(fā)現,當使用GradNorm 時,大多數蒸餾方法的性能都很差。這是因為GradNorm 完全忽略了任務損失應該為主導地位。而且,與任務損失相比,蒸餾損失通常非常大或者非常小。例如,在CC 中,蒸餾梯度的模長約為任務梯度的100 倍,
表3 多任務學習方法GradNorm 和Uncertainly 在CIFAR-100 上與自動梯度混合方法(AGB)相比的Top-1 測試準確度(%)。由于訓練過程中的梯度爆炸,一些方法顯示出非常差的準確性或無法訓練出有效的結果(用表示)。null 表示此蒸餾方法不支持多任務學習方法。
而在PKT 中,蒸餾梯度的模長約為任務梯度的0.001倍。因此,GradNorm 簡單地平衡2 個損失將會導致整個訓練過程不穩(wěn)定。相比之下,自動梯度混合方法將混合梯度的模長限制為與任務梯度的模長相同。因此,自動梯度混合方法在獲得穩(wěn)定訓練過程的同時,可以保留任務梯度占據主導地位這一重要信息。
本文驗證了在自動梯度混合方法中訓練早期和訓練后期準確率的保序性。在CIFAR-100 上使用AT 蒸餾方法進行這些實驗,在NNI 上用VGG13 作為教師網絡,VGG8 作為學生網絡。計算早期(第5輪)的準確率和整個訓練結束的最終準確率之間的相關系數。本文還對手動調節(jié)方法進行了這些實驗,α設置為1,β從0.003 變化到30 000。為了公平地比較,本文選擇結果接近收斂時的最后80 次實驗來驗證相關性。
如圖3所示,下圖為自動梯度混合方法,其相關系數為0.724,遠高于上圖中手動調節(jié)方法的0.410。這個實驗說明了使用自動梯度混合方法時早期輪次表現較好的設置同樣可以運用到晚期輪次。因此,預熱策略可以在不損失性能的前提下大幅提升自動梯度混合方法的效率。
圖3 最佳精度與早期輪次精度之間的相關性
本文就預熱階段設置的熱身輪次和預熱階段用于離散化的步長進行了消融實驗。在CIFAR-100上使用KD 蒸餾方法進行實驗,教師網絡為Res-Net32x4,學生網絡為ResNet8x4。
圖4 顯示了準確率、時間開銷與步長的關系??梢钥吹?當步長從0.2 變小后,時間開銷增大,對應的結果略有上升;而當步長變大后,實際上的節(jié)省的時間相當有限,而性能也會出現一定程度的下降。圖5 則顯示了準確率、時間和熱身輪次的關系。可以發(fā)現,與前面步長類似,選取更小的熱身輪次并不會導致運行時間有一個顯著的變小。而當熱身輪次提升后,時間開銷增大了,對于實驗的準確率也沒有提升太多。因此本文取的熱身輪次和步長并不具備特殊性,取附近的幾個值結果差異不會太大,這也說明了是前面模長約束在方法中起到了主要的作用而預熱的等間距選取最優(yōu)的策略只是用于輔助的。
圖4 準確率、時間與步長之間的關系
圖5 準確率、時間與熱身輪次之間的關系
圖2 中的結果也顯示了自動梯度混合方法的高效性。圖2 中圓點表示一次超參數優(yōu)化方法實驗的準確性。隨著訓練實驗的增加,每條虛線表示 超參數優(yōu)化方法的最佳準確性。三角形標記表示自動梯度混合方法的結果,該方法需要大約1.50 次實驗時間才能達到72.48%的準確率。在知識蒸餾中尋找損失權重時,手動調整會受到大的搜索空間的影響。通過使用貝葉斯優(yōu)化或者是其他算法改進搜索過程,超參數優(yōu)化方法會高效一些,但是仍然有著比較大的搜索空間。相比之下,自動梯度混合方法通過約束混合梯度的模長并僅僅在預熱階段在方向上進行搜索,從而顯著減少了搜索空間。如圖2 所示,超參數優(yōu)化方法需要10 次以上的實驗才能達到與自動梯度混合方法相當的精度。因此,與超參數優(yōu)化方法相比,自動梯度混合方法效率更高。
本文提出了一種自動梯度混合方法,可以有效地為絕大部分知識蒸餾方法找到合適的損失權重。利用蒸餾損失是用于輔助任務損失這一先驗知識,自動梯度混合方法通過減少超參數搜索空間來優(yōu)化搜索過程。自動梯度混合方法只搜索梯度方向,即2 個損失梯度之間的角度,同時將混合梯度的模長約束為與任務損失梯度模長相同。本文在13 種不同的師生網絡組合之間對10 種不同的知識蒸餾方法進行了實驗。自動梯度混合方法在使用更少的運算資源的前提下在70%的蒸餾方法上性能超過了手動調節(jié)方法,這說明自動梯度混合方法具有更好的效果以及更高的效率。本文工作的前提是假設當有多個蒸餾損失時,所有的蒸餾損失共享相同的權重。未來,可以將本文工作擴展到具有多種蒸餾損失的情況。