劉小雷,高凱新,王 勇
(天津大學(xué) 數(shù)學(xué)學(xué)院,天津 300350)
近些年來,深度學(xué)習(xí)已經(jīng)在計(jì)算機(jī)視覺和自然語言處理等領(lǐng)域取得了重要的進(jìn)展.然而隨著研究的深入,模型越來越復(fù)雜,往往需要耗費(fèi)大量的訓(xùn)練時(shí)間和計(jì)算成本.因此,采用有效的訓(xùn)練方法是十分有必要的.以隨機(jī)梯度下降(SGD)為代表的一階優(yōu)化方法是當(dāng)前深度學(xué)習(xí)中最常用的方法.近些年來,一系列SGD的改進(jìn)算法被提出并也被廣泛于應(yīng)用深度學(xué)習(xí)中,比如,動(dòng)量SGD (SGDM[1]),Adagrad[2],Adam[3].這些一階優(yōu)化方法具有更新速度快,計(jì)算成本低等優(yōu)點(diǎn),但是也具有收斂速度慢,需要進(jìn)行復(fù)雜調(diào)參等缺點(diǎn).
通過曲率矩陣修正一階梯度,二階優(yōu)化方法可以得到更為有效的下降方向,使得收斂速度大大加快,減少了迭代次數(shù)和訓(xùn)練時(shí)間.對(duì)于有著上百萬甚至更多參數(shù)的深度神經(jīng)網(wǎng)絡(luò)而言,其曲率矩陣的規(guī)模是十分巨大的,這樣大規(guī)模的矩陣的計(jì)算,存儲(chǔ)和求逆在實(shí)際計(jì)算中是難以實(shí)現(xiàn)的.因此,對(duì)曲率矩陣的近似引起了廣泛的研究.其中最基本的方法是對(duì)角近似,其在實(shí)際計(jì)算中取得了較好的效果,但是在近似過程中丟失了很多曲率矩陣的信息,而且忽略了參數(shù)之間的相關(guān)性.在對(duì)角近似的基礎(chǔ)上,一些更為精確的算法也被提出,這些算法不再局限于曲率矩陣的對(duì)角元素,同時(shí)也考慮了非對(duì)角元素的影響.這些方法對(duì)曲率矩陣的研究都取得的一定了進(jìn)展[4–8].但是如何在深度學(xué)習(xí)中更加有效地利用曲率矩陣得到更有效的算法,仍然是應(yīng)用二階優(yōu)化方法面臨的重要挑戰(zhàn).
自然梯度下降可以被視為一種二階優(yōu)化方法,其中自然梯度定義為梯度與模型的Fisher 信息矩陣的乘積.該方法最初在文獻(xiàn)[9]中被提出,其在深度學(xué)習(xí)中有著重要的應(yīng)用.文獻(xiàn)[10]中提出了一種近似全連接神經(jīng)網(wǎng)絡(luò)中自然梯度的有效方法,稱之為K-FAC.K-FAC算法首先通過假設(shè)神經(jīng)網(wǎng)絡(luò)各層之間的數(shù)據(jù)是獨(dú)立的,將Fisher 信息矩陣近似為塊對(duì)角矩陣;將每個(gè)塊矩陣近似為兩個(gè)更小規(guī)模矩陣的克羅內(nèi)克乘積,通過克羅內(nèi)克乘積的性質(zhì)可以有效計(jì)算近似后的Fisher 信息矩陣及其逆矩陣.K-FAC有效地減少了自然梯度下降的計(jì)算量并取得了很好的實(shí)驗(yàn)效果.這一方法也被應(yīng)用到其他的神經(jīng)網(wǎng)絡(luò)中,包括卷積神經(jīng)網(wǎng)絡(luò)[11–13],循環(huán)神經(jīng)網(wǎng)絡(luò)[14],變分貝葉斯神經(jīng)網(wǎng)絡(luò)[15,16].通過設(shè)計(jì)K-FAC的并行計(jì)算框架,其在大規(guī)模問題中的有效性也得到了驗(yàn)證[17,18].
K-FAC 算法在眾多問題中都有著很好的表現(xiàn),在保持K-FAC 算法有效性的前提下,進(jìn)一步降低計(jì)算成本和減少計(jì)算時(shí)間是非常值得研究的問題.在本文中,我們基于K-FAC 算法的近似思想,結(jié)合擬牛頓法的思想,提出了一種校正Fisher 信息矩陣的有效方法.該方法的主要思想是先用K-FAC 方法進(jìn)行若干次迭代,保存逆矩陣的信息;在后續(xù)迭代中利用該逆矩陣以及新的迭代中產(chǎn)生的信息,結(jié)合Sherman-Morrison 公式進(jìn)行求逆計(jì)算,大大減少了迭代時(shí)間.實(shí)驗(yàn)中,改進(jìn)的KFAC 算法比K-FAC 算法有相同甚至更好的訓(xùn)練效果,同時(shí)大大減少了計(jì)算時(shí)間.
神經(jīng)網(wǎng)絡(luò)的訓(xùn)練目標(biāo)是獲得合適的參數(shù)θ 來最小化目標(biāo)函數(shù)h(θ),給定損失函數(shù):
其中,x,y分別是訓(xùn)練輸入和標(biāo)簽,θ是模型的參數(shù)向量,p(y|x,θ)表示預(yù)測分布的密度函數(shù).那么Fisher 信息矩陣的定義如下:
文獻(xiàn)[10]中給出了自然梯度的定義:F?1?θh.在實(shí)際計(jì)算中面臨的主要挑戰(zhàn)是計(jì)算自然梯度,也就是計(jì)算逆矩陣F?1及其與?θh的乘積.在深度學(xué)習(xí)中,由于F的規(guī)模太大,直接計(jì)算F?1是不切實(shí)際的.K-FAC 提供了一種近似F?1的有效方法,其近似過程可以分為兩個(gè)步驟.首先,K-FAC 將矩陣F按網(wǎng)絡(luò)的各層分割成塊矩陣,通過假設(shè)不同層之間數(shù)據(jù)是獨(dú)立的,將F近似為塊對(duì)角矩陣;其次將每個(gè)塊矩陣近似為兩個(gè)更小規(guī)模矩陣的克羅內(nèi)克乘積,根據(jù)克羅內(nèi)克乘積求逆及運(yùn)算的相關(guān)性質(zhì),大大減少了計(jì)算量.
考慮一個(gè)有L層的神經(jīng)網(wǎng)絡(luò),al?1,sl分別表示第l層的輸入和輸出,那么sl=Wlal?1,其中Wl為第l層的權(quán)重矩陣,l∈{1,2,···,L}.方便起見,定義如下的符號(hào):
那么Fisher 信息矩陣可以被表示為:
即為:
因此F可以被看作一個(gè)L×L的塊矩陣.定義Fij=E[vec(Dθi)vec(Dθi)T],i,j∈{1,2,···,L}.通過假設(shè)神經(jīng)網(wǎng)絡(luò)各層之間數(shù)據(jù)的獨(dú)立性,也就是Fij=0(i≠j),那么F可以被近似為:
然而由于每個(gè)塊矩陣Fll的規(guī)模仍然很大,因此需要進(jìn)一步近似.每個(gè)塊矩陣Fll可以被寫作:
因此Fisher 信息矩陣的逆矩陣可以被近似為:
為了保持訓(xùn)練的穩(wěn)定性,需要對(duì)克羅內(nèi)克因子Al?1,l?1和Gll添加阻尼如下:
其中,λ是阻尼參數(shù),πl(wèi)理論上可以是一個(gè)任意的正數(shù),但在實(shí)驗(yàn)中發(fā)現(xiàn)根據(jù)以下公式計(jì)算的πl(wèi)是一個(gè)更好的選擇.
其中,dl?1和dl分別是矩陣Al?1,l?1和Gll的維度.因此由上述內(nèi)容可以得到訓(xùn)練中第l層參數(shù)的更新規(guī)則如下:
其中,η是學(xué)習(xí)率,m表示迭代次數(shù).
圖1是K-FAC 算法近似過程示意.
圖1 K-FAC 算法近似過程示意
在實(shí)際計(jì)算中,Fisher 信息矩陣F的主對(duì)角線上的每個(gè)塊矩陣的規(guī)模仍然很大,直接對(duì)矩陣{F11,F22,···,FLL}進(jìn)行求逆的計(jì)算成本很高.K-FAC 算法采用的方法是將Fisher 信息矩陣近似為兩個(gè)小規(guī)模矩陣的克羅內(nèi)克乘積,從而將求解大規(guī)模矩陣的逆轉(zhuǎn)化為求解兩個(gè)小規(guī)模矩陣的逆,大大降低了求逆的成本.在本文中我們結(jié)合K-FAC 方法采取的這種近似方法,給出了Fisher信息矩陣F的另一種近似,其求逆的費(fèi)用更低.我們基于下面的Sherman-Morrison 公式來改進(jìn)K-FAC 算法.
Sherman-Morriso 公式:假設(shè)X∈Rm×m是可逆矩陣,p,q∈Rm為任意列向量,則可逆當(dāng)且僅當(dāng)1+qTX?1p≠0,而且如果X+pqT可逆,逆矩陣可以由以下公式得到:
由上述公式可以看出,如果矩陣X的逆已知(或者很容易求),那么利用Sherman-Morrison 公式可以將矩陣的求逆運(yùn)算轉(zhuǎn)化為矩陣向量乘積,從而可以減少大量的計(jì)算時(shí)間.因此我們根據(jù)Sherman-Morrison 公式,結(jié)合擬牛頓的算法思想,提出了一種K-FAC 算法的改進(jìn)算法.在實(shí)際計(jì)算中,每次迭代均更新逆矩陣需要很高的計(jì)算成本,因此實(shí)驗(yàn)中一般設(shè)置若干次迭代更新一次逆矩陣.在下文中,我們用k表示逆矩陣更新的次數(shù).我們提出的算法主要是對(duì)K-FAC 算法的求逆運(yùn)算進(jìn)行了進(jìn)一步的改進(jìn).其主要思想是用K-FAC 算法先進(jìn)行k次求逆運(yùn)算,保存第k次求逆得到的逆矩陣信息;后續(xù)迭代中利用該逆矩陣的信息以及在新的迭代中產(chǎn)生的信息,結(jié)合Sherman-Morrison 公式進(jìn)行求逆運(yùn)算.下面我們以矩陣A00,G11為例說明改進(jìn)的方法,對(duì)于矩陣All(l∈{1,2,···,L?1})和Gll(l∈{2,3,···,L})均采用相同的近似方法.為方便表示,我們?cè)谙挛闹惺÷韵聵?biāo).
(1)首先,按照K-FAC 算法進(jìn)行k次求逆運(yùn)算,在求出逆矩陣(A(k))?1和(G(k))?1后保留逆矩陣的信息.
(2)其次,矩陣A(k+1)和G(k+1)表示的是在第k+1 次更新逆矩陣時(shí)計(jì)算得到的矩陣,利用矩陣A(k+1)和G(k+1)構(gòu)造向量u(k+1)和v(k+1).從現(xiàn)有文獻(xiàn)中可以看出,矩陣A(k+1)和G(k+1)的主對(duì)角線上的元素占主導(dǎo)性的信息,因此我們選取其主對(duì)角線上的元素來構(gòu)造向量u(k+1)和v(k+1).我們選取:
那么u(k+1)u(k+1)T和v(k+1)v(k+1)T就是保留矩陣A(k+1)和G(k+1)主對(duì)角線元素的對(duì)角矩陣.
(3)最后,利用Sherman-Morrison 公式可以得到:
其中,α,β是兩個(gè)合適的正數(shù).
在改進(jìn)的K-FAC 算法中,主要是結(jié)合Sherman-Morriso 公式對(duì)Fisher 信息矩陣進(jìn)行了近似,因此改進(jìn)的算法在計(jì)算矩陣A,G及其逆矩陣部分與K-FAC 算法有所區(qū)別,其余部分與K-FAC 算法相同.文獻(xiàn)[1]中對(duì)K-FAC 算法整體計(jì)算復(fù)雜度進(jìn)行了詳細(xì)分析,因此在表1中,我們主要給出了兩種方法在計(jì)算矩陣A,G及其逆矩陣的計(jì)算復(fù)雜度對(duì)比.
表1 K-FAC和改進(jìn)的K-FAC 算法的求逆計(jì)算復(fù)雜度對(duì)比
對(duì)于改進(jìn)的K-FAC 算法,前k次與K-FAC 算法一致.在實(shí)際計(jì)算中,k的取值遠(yuǎn)遠(yuǎn)小于t.比如在實(shí)驗(yàn)中我們選擇k=10,t=39 100.因此前t次的計(jì)算成本占比很低.在后續(xù)的更新過程中,改進(jìn)的K-FAC 算法一方面矩陣求逆部分都轉(zhuǎn)化成了矩陣乘法,計(jì)算復(fù)雜度由O(n3)降為O (n2);另一方面,在后續(xù)迭代中,我們僅需要計(jì)算矩陣A和G的主對(duì)角線元素,可以直接將向量元素對(duì)應(yīng)相乘,不再需要進(jìn)行矩陣乘法,計(jì)算復(fù)雜度由O(n2)降為O(n).因此,通過上述兩個(gè)方面改進(jìn)的KFAC 算法可以減少大量的計(jì)算時(shí)間.算法1 總結(jié)了改進(jìn)的K-FAC 算法的流程.
算法1.改進(jìn)的K-FAC 算法η λ TFIM,Tinv k輸入:訓(xùn)練集T,學(xué)習(xí)率,阻尼參數(shù),Fisher 信息矩陣及其逆矩陣的更新頻率,逆矩陣更新次數(shù)θ輸出:模型參數(shù)All(l∈{0,1,···,L?1})Gll(l∈{1,2,···,L})m=0,t=0初始化參數(shù)和,;While 未達(dá)到終止條件do m≡0(mod TFIM)if then t<=k if then{All}L?1 l=0,{Gll}Ll=1根據(jù)式(2)計(jì)算因子m≡0(mod Tinv)if then t<=k if then{A?1 ll}L?1l=0,{G?1ll}Ll=1根據(jù)式(4)計(jì)算逆矩陣else u,v{A?1 ll}L?1l=0,{G?1ll}Ll=1計(jì)算向量,根據(jù)式(7)和式(8)計(jì)算逆矩陣end if t=t+1 end if{θl}Ll=1根據(jù)式(5)更新參數(shù)m=m+1 end While
為了說明改進(jìn)的K-FAC 算法的有效性,我們?cè)诔S玫膱D像分類數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn).實(shí)驗(yàn)中,數(shù)據(jù)集選取的是CIFAR-10和CIFAR-100 數(shù)據(jù)集[19].這兩個(gè)數(shù)據(jù)集都是由60000 張分辨率為3 2×32的彩色圖像組成,訓(xùn)練集和測試集分別有50000和10000 張彩色圖像.CIFAR-10 數(shù)據(jù)集中圖像有10個(gè)不同的類,每類有6000 張圖像.CIFAR-100 數(shù)據(jù)集中圖像有100個(gè)不同的類,每類有600 張圖像.實(shí)驗(yàn)中我們對(duì)兩個(gè)數(shù)據(jù)集的圖像都采用數(shù)據(jù)增強(qiáng)技術(shù),包括隨機(jī)裁剪和水平翻轉(zhuǎn).我們選擇動(dòng)量隨機(jī)梯度下降(SGDM)和K-FAC 算法作為對(duì)比標(biāo)準(zhǔn),在ResNet20[20]上比較了這兩個(gè)方法和改進(jìn)的K-FAC 算法的表現(xiàn).
實(shí)驗(yàn)中我們采用的深度學(xué)習(xí)框架是TensorFlow,訓(xùn)練的硬件環(huán)境為單卡 NVIDIA RTX 2080Ti GPU.實(shí)驗(yàn)中批量大小(batch-size)設(shè)置為128,動(dòng)量為0.9,最大迭代次數(shù)為39100,初始學(xué)習(xí)率SGDM 設(shè)置為0.03,K-FAC 算法和改進(jìn)的K-FAC 算法設(shè)置為0.001,學(xué)習(xí)率每16000 次迭代衰減為原來的0.1.對(duì)于K-FAC 算法和改進(jìn)的K-FAC 算法,Fisher 信息矩陣及其逆矩陣的更新頻率分別為TFIM=10,Tinv=100,阻尼為0.001.在改進(jìn)的K-FAC 算法中,我們令 α=β=0.1.對(duì)于所有的方法,我們均沒有采用權(quán)重衰減.
在表2中,我們給出了在CIFAR-10 數(shù)據(jù)集上SGDM,K-AFC 算法和改進(jìn)的K-FAC 算法的訓(xùn)練精度及時(shí)間比較,其中,K-FAC 算法給出了每次迭代均更新逆矩陣(1:1)和100 次迭代更新逆矩陣(100:100)的實(shí)驗(yàn)結(jié)果.表中第一行給出了各種算法的訓(xùn)練精度比較,其余各行別給出了各種方法每次迭代的平均訓(xùn)練時(shí)間以及測試精度首次達(dá)到89%,90%,91%,92%,93%的訓(xùn)練時(shí)間,表中最后一列給出了改進(jìn)的K-FAC 算法(100:100)比K-FAC 算法(100:100)減少的訓(xùn)練時(shí)間.因?yàn)镃IFAR-100 數(shù)據(jù)集和CIFAR-10 數(shù)據(jù)集圖像數(shù)量和分辨率相同,這兩個(gè)數(shù)據(jù)集上每次迭代的訓(xùn)練時(shí)間幾乎相同,所以我們僅給出了在CIFAR-10 數(shù)據(jù)集上的結(jié)果,在CIFAR-100 數(shù)據(jù)集也有類似的結(jié)果.
表2 SGDM,K-FAC和改進(jìn)的K-FAC 算法的訓(xùn)練精度及時(shí)間比較
從表2可以看出,K-FAC在不同的逆矩陣更新頻率下((1:1)和(100:100))的測試精度相差不大,但每次迭代均更新逆矩陣耗費(fèi)了大量的計(jì)算時(shí)間(每次迭代平均增加了2.07 s).結(jié)合之前的相關(guān)工作,在本文中我們更多關(guān)注若干次迭代更新逆矩陣的實(shí)驗(yàn)結(jié)果.因此,在后文中,我們主要基于K-FAC 算法(100:100)的實(shí)驗(yàn)結(jié)果進(jìn)行討論.
從測試精度看,改進(jìn)的K-FAC 算法與K-FAC 算法相差不大.在CIFAR-10 數(shù)據(jù)集上,改進(jìn)的K-FAC 算法的測試精度略低于K-FAC 算法,但在CIFAR-100 數(shù)據(jù)集上,改進(jìn)的K-FAC 算法的測試精度高于K-FAC算法.從訓(xùn)練時(shí)間看,SGDM 從89%到90%,K-FAC 算法從91%到92%以及改進(jìn)的K-FAC 算法從從91%到92%的訓(xùn)練時(shí)間差距較大,這是因?yàn)樵趯W(xué)習(xí)率衰減之前,測試精度在較多的迭代中變化不大,衰減后才達(dá)到了相應(yīng)的測試精度.改進(jìn)的K-FAC 算法每個(gè)迭代的平均訓(xùn)練時(shí)間與SGDM 相比,僅增加了0.006 s,比KFAC 減少了0.023 s.從到達(dá)各個(gè)測試精度的時(shí)間看,改進(jìn)的K-FAC 算法均比K-FAC 算法減少了大量的訓(xùn)練時(shí)間.比如在測試精度達(dá)到91%時(shí),K-FAC 算法比SGDM 多花費(fèi)了8 s,而我們改進(jìn)的K-FAC 算法比SGDM 減少了356 s.從表格最后一行看,SGDM 最終的測試精度達(dá)不到93%,K-FAC 算法和改進(jìn)的KFAC 算法都可以達(dá)到93%,而且改進(jìn)的K-FAC 算法減少了237 s.從這些結(jié)果可以看出,我們改進(jìn)的K-FAC算法可以達(dá)到與K-FAC 算法相近的訓(xùn)練精度,同時(shí)減少了大量的訓(xùn)練時(shí)間,而且與一階優(yōu)化方法相比在速度與精度上都具有一定的優(yōu)勢.
圖2給出了在CIFAR-10 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果,分別給出了SGDM,K-FAC 算法(100:100)和改進(jìn)的KFAC 算法(100:100)的訓(xùn)練損失,訓(xùn)練精度和測試精度隨迭代的變化曲線.在圖中可以看出二階優(yōu)化方法(KFAC 算法和改進(jìn)的K-FAC 算法)收斂速度明顯快于SGDM,改進(jìn)的K-FAC 算法與K-FAC 收斂速度相近.從訓(xùn)練損失看,所有的方法都可以達(dá)到較低的訓(xùn)練損失,SGDM的訓(xùn)練損失略高;從精度看,所有的方法都可以達(dá)到很高的測試精度,我們改進(jìn)的K-FAC 算法在前期表現(xiàn)好于K-FAC.
圖3分別給出了SGDM,K-FAC 算法和改進(jìn)的KFAC 算法在CIFAR-100 數(shù)據(jù)集上的訓(xùn)練損失,訓(xùn)練精度和測試精度隨迭代的變化曲線.從圖中可以看出,CIFAR-100 數(shù)據(jù)集和CIFAR-10 數(shù)據(jù)集有著相似的實(shí)驗(yàn)結(jié)果.但從測試精度看,改進(jìn)的K-FAC 算法好于KFAC.從這些結(jié)果我們可以看出,我們改進(jìn)的K-FAC算法與K-FAC 算法相比,有著相似甚至更好的實(shí)驗(yàn)效果,說明我們提出的Fisher 信息矩陣的逆矩陣進(jìn)一步近似的方法是有效的.
圖2 CIFAR-10 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果
在深度學(xué)習(xí)中應(yīng)用二階優(yōu)化問題面臨的一個(gè)重要挑戰(zhàn)是計(jì)算曲率矩陣的逆矩陣,由于深度神經(jīng)網(wǎng)絡(luò)擁有海量的參數(shù)導(dǎo)致其曲率矩陣的規(guī)模巨大而難以求逆.在本文中,我們基于K-FAC 算法對(duì)Fisher 信息矩陣的近似方法,結(jié)合擬牛頓方法的思想,在前期少量迭代中利用原方法訓(xùn)練,后續(xù)迭代利用新計(jì)算的矩陣信息構(gòu)造秩–1 矩陣進(jìn)行近似.利用Sherman-Morrison 公式大大降低了計(jì)算復(fù)雜度.實(shí)驗(yàn)結(jié)果表明,我們改進(jìn)的KFAC 算法與K-FAC 算法有著相似甚至更好的實(shí)驗(yàn)效果.從訓(xùn)練時(shí)間看,我們的方法比原方法減少了大量的計(jì)算時(shí)間,與一階優(yōu)化方法相比我們改進(jìn)的方法仍具有一定的優(yōu)勢.但如何在深度學(xué)習(xí)中更加有效地利用曲率矩陣的信息,得到更有效更實(shí)用的算法,仍然是在深度學(xué)習(xí)中應(yīng)用二階優(yōu)化方法面臨的重要挑戰(zhàn).
圖3 CIFAR-100 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果