鄭德重, 楊媛媛, 黃浩哲, 謝 哲, 李文濤
(1. 中國科學院上海技術物理研究所 醫(yī)學影像信息學實驗室, 上海 200080; 2. 中國科學院大學,北京 100049; 3. 復旦大學附屬腫瘤醫(yī)院, 上海 200032)
在處理機器學習分類問題時,模型準確率的提高總是會被優(yōu)先關注.較高的準確率固然很重要,但在有限樣本的情況下,學習到的模型準確率可能很高,但模型的可靠性并不一定很高.在統(tǒng)計學中,置信度是度量系統(tǒng)可靠性的一個典型指標.置信度的重要性在于,如果一個決策支持系統(tǒng)對某個樣本進行預測的信心太低,則可能需要人類專家參與決策過程.在實際應用中,置信度還具有識別未知類別樣本的能力.因為當分類時,如果置信度較低,則表明要鑒別的樣本與建立模型時用到的樣本差異度較大.借助置信度能夠擴大樣本訓練范圍,通過再訓練改善模型,提高模型的泛化能力[1-4].因此,良好的置信度應該是模型設計的一部分[1].在部署任何機器學習分類模型時,好的模型不僅要求有較高的準確率,而且還要能以較高的置信度進行正確分類.
機器學習中大多數生成的模型本質上都是概率模型,可以直接得到這樣的置信度.但大多數判別模型無法直接獲得每個類的預測概率,而是將相關的非概率分數作為一種替代,如支持向量機(SVM)分類器中的最大間隔[1,5-6].在評估一個神經網絡模型的好壞時,通常會使用各種不同類型的分數來衡量模型的置信度.常用的方法是將最后一級輸出單元通過Softmax軟件歸一化.此外,還可以利用輸出單元的熵來計算,當預測某個樣本的不確定性越低時,熵越小.雖然這些從輸出端得到的分數與置信度相關,但使用這些分數度量置信度也存在一些缺陷,一些不可察覺的擾動可能改變神經網絡的輸出值.文獻[7]通過實驗在圖像分類樣本中加入噪聲擾動,原本能夠正確分類的樣本在加入擾動后可以得到完全相反的預測結果,而加入噪聲的圖像在人的視覺觀察中感覺不到任何變化.神經網絡相比于人類在對數據的理解方面存在巨大差異,可能存在某些反直覺的情況“盲區(qū)”,也間接說明了神經網絡可能存在某些人類難以覺察的不確定性,這種不確定將會直接影響輸出結果和置信度[7-8].對于分類而言,將最后一級單元通過Softmax軟件獲得的概率最大值視為分類置信度是不準確的,因為這種方式忽略了與其余類預測概率間的關系,與真正的置信度之間有時存在著一定的偏差[3,9].既然從模型的外部計算出來的置信度不一定能夠代表其真實的概率估計,那么可以嘗試從模型內部入手.
為了獲得一個對于神經網絡分類模型可靠的置信度分數,許多研究者將注意力集中在神經網絡的嵌入階段,這些嵌入層被證明可以在許多相關任務中提供更好的語義表示[10-12].使用這種語義表示,通過估計嵌入空間中樣本的局部密度來計算置信度分數,進而可以計算樣本屬于不同類別的概率.基于此,本文在嵌入空間提出一種基于距離置信度分數(DCS)的計算方法來度量模型的置信度.此方法不依賴于特定的分類模型,可以嵌入任何分類器中進行置信度計算.通過實驗證明所提方法不僅可以用在單一模態(tài)分類模型中,還可對此進行擴展,將其用在多模態(tài)分類模型的置信度度量中.綜上所述,本文的主要貢獻如下:① 提出一種不依賴于特定分類模型的置信度度量方法,該方法不僅可以用在單模態(tài)分類問題,還可以用在多模態(tài)分類問題中;② 對于多模態(tài)分類問題,該方法可以量化評估單模態(tài)數據對于模型最終決策的影響,同時還可以知道不同模態(tài)信息對于最終決策時的重要程度差異.
本文提出的基于距離置信度分數,主要借鑒以往兩方面的研究:神經網絡置信度分數和多模態(tài)融合研究.因此,下文將從這兩個方面介紹相關工作.
Bayes模型在數學上提供了一種用來計算置信度的基礎框架.文獻[13-14]使用神經網絡上的參數計算后驗分布,用于估計預測不確定性來進行置信度的度量.文獻[15]利用Bayes網絡模型的拓撲結構結合多模態(tài)信息對非常規(guī)性突發(fā)事件的可能性進行量化評估.雖然利用Bayes方法來計算置信度的數學理論成熟,但在實際應用中實現起來相對困難,且計算成本高.文獻[16]提出在模型測試時用dropout操作作為Bayes網絡的一種簡單替代,通過輸出結果觀察模型的不確定性.文獻[17]提出使用對抗訓練來改進網絡基于熵分數的不確定性度量方法.模型置信度通常是從輸出端的激活函數或其歸一化中計算得到的,這些方法大多是通過對模型輸出端進行外部觀測,并將觀測結果用于計算模型置信度.鑒于通過外部輸出計算置信度有較多不足之處,希望能找到一種通過模型內部觀測的方法,即找到一種能夠代表真實概率估計的方法來計算置信度.
在度量學習研究領域中,如人臉識別、圖像檢索等,通過嵌入的方法可以學習從原始特征空間到一個低維稠密向量空間(嵌入空間)的映射,在該空間中樣本的相似度可以通過距離進行度量.文獻[18]在使用ImageNet數據集訓練一個深層網絡時,通過嵌入得到一個圖像語義豐富的表示,并在此基礎上進行分類.文獻[10]基于深度嵌入的度量學習思想,在有限標記的成對樣本之間進行相似性學習,完成不同圖像的匹配任務.文獻[11]通過提取圖像特征學習高層次的嵌入語義來實現圖像壓縮.文獻[1]還通過對抗性實驗證明嵌入空間不僅含有豐富的語義信息,而且還具有一定的抗干擾能力.由此可知,通過嵌入方式可以將網絡提取的特征從原來的特征空間映射到一個稠密的、可度量的嵌入空間中.在這個嵌入空間對樣本進行局部概率密度估計,有望找到一種可以度量模型的置信度分數.
融合不同來源的信息主要有3種方式:早期融合、中期融合和后期融合.早期融合是在訓練模型之前,將不同模態(tài)的特征串聯起來,然后從串聯特征中進行學習.中期融合首先是對各個單模態(tài)數據的特征進行一些初步學習,然后將學到的初級特征通過第2階段融合加工進一步學習,最后將這些學習到的綜合特征用于最終決策.中期融合方式目前大多是通過深度學習來實現的[19].后期融合方式是針對各種不同模態(tài)數據獨立使用不同算法,然后根據任務特點使用一些技術組合方式進行最終決策.早期融合方式的主要優(yōu)勢是可以識別不同模態(tài)特征之間的關系,但該方式無法充分利用每種模態(tài)數據中自己的模式.早期融合只適用于相同類型數據間的融合,不同類型之間的數據不能直接融合,例如圖像數據和文本數據.此外,早期融合方式著重于組合不同模態(tài)的特征,因此其通常具有很高的特征與樣本比,容易導致分類時模型過擬合.后期融合方式與中期融合方法相比,后期融合方法實現更簡單,但無法充分利用不同模態(tài)間的交互信息,只能通過每種模態(tài)信息獨立判斷后進行綜合決策.中期融合方式目前在利用模式內信息和模式間交互信息方面具有一定的優(yōu)勢,并且可以充分利用深度學習強大的特征提取能力[20].本文提出的基于距離置信度分數來評估多模態(tài)分類模型置信度就是通過中間融合方式融合信息的.
接下來,先介紹兩種通常使用的從外部輸出端得到的模型置信度分數.然后再介紹所提的從內部嵌入空間評估模型的置信度分數計算方法,最后介紹這種新的置信度分數在多模態(tài)分類模型中的構建.
給定一個訓練好的模型,通常使用以下兩種分數來評估分類的置信度:基于最大距離置信度分數(MMCS)和基于熵置信度分數(ECS).文獻[21]的實證研究表明,對外部輸出而言,用這兩種方法是相對有效的評估模型置信度的方法,文獻[1]也曾用這兩種分數評估模型的置信度.兩種分數定義如下:① 基于最大距離的置信度分數.歸一化后,網絡輸出層中的最大激活單元.② 基于熵的置信度分數.網絡輸出層中激活單元的(負)熵.
所提出的基于距離置信度分數的主要思路是借鑒度量學習方法,在網絡特征提取后添加一層嵌入層,將原來網絡中提取的特征進行映射,映射到一個語義豐富且可以度量的稠密空間中.在該嵌入空間中估計樣本的局部密度,進而計算模型置信度,如圖1所示.由圖1可知,左側特征提取部分用來提取樣本特征;右側兩層全連接層,一個用于嵌入獲得樣本的向量表示,一個用于映射到輸出以獲取相應的預測值.
圖1 基于距離置信度分數的計算示意圖Fig.1 Schematic diagram of distance confidence score calculation
(1)
圖2 最近k個點的密度估計Fig.2 Estimation of density of the nearest k points
(2)
式中:max (·)為測試樣本xi預測類別最大的分數,即最有可能的分類.
2.2.2利用中心損失提高嵌入效果 在度量學習應用中,鑒別的樣本對象之間差異度相對較小,其分類模型要在能夠對其進行細粒度鑒別的同時保持穩(wěn)健性.早期主要是通過交叉熵損失來訓練優(yōu)化模型的,之后有學者提出了三重態(tài)損失訓練模型,但在訓練過程中三重態(tài)樣本配對組合的差異度會影響模型的學習速度[22].文獻[23]提出將中心損失用于面部識別,根據中心損失的梯度更新每個mini-batch中心,作為三重態(tài)損失的一種替代取得了良好的效果.文獻[24]在少樣本學習中使用了類似的方法,不斷更新mini-batch的中心來進行優(yōu)化,在場景識別任務中取得了不錯的效果.中心損失優(yōu)化時,最小化具有相同標簽的樣本到其樣本中心之間的距離,將屬于同一類的數據點聚集在一起,以獲得在嵌入空間更好的向量表示[24].為了提高嵌入表達效果,使用中心損失來優(yōu)化模型.中心損失可表示為
(3)
式中:Lso為交叉熵損失;Lcen為中心損失;f(xi)為第i個訓練樣本通過網絡后得到的高維特征向量;hci∈RD為ci的樣本中心,ci為xi的樣本類別標簽,xi∈RD,D為特征向量的維度;M為mini-batch的樣本數量;λ為超參數.
圖3 基于距離置信度分數的多模態(tài)分類網絡構建示意圖Fig.3 Schematic diagram of multimodal classification network construction based on distance confidence score
由上述可知,嵌入層添加在模型的特征提取模塊之后,對于多模態(tài)分類模型可以使用相同的方法在各自模態(tài)特征提取后添加嵌入層用于計算置信度,如圖3所示,其中:N為輸入信息序號.在單一模態(tài)分類中由于信息源只有一個,不用考慮模式中特征重要程度的差異.但在多模態(tài)分類中,由不同信息來源間的模式提取到的特征重要程度存在差異,因此引入注意力機制.注意力機制最早在計算機視覺任務中提出,隨后在自然語言處理領域也開始逐漸應用,隨著BERT(Bidirectional Encoder Representation from Transformers)模型和GPT(Generative Pre-Training)模型在該領域中取得顯著的效果,人們也越來越注意到注意力機制.注意力機制可以幫助模型將提取到的特征賦予不同權重,對關鍵、重要信息進行強化,幫助模型做出更加準確的判斷[25-28].在多模態(tài)分類網絡的特征提取階段,為了強化不同模態(tài)提取自己的關鍵信息,在各自模態(tài)中做了注意力機制的處理.在提取圖像特征時,使用了通道注意力和空間注意力機制[29];在對文本類結構化信息提取時,使用了自注意力機制[27];最后,在各自模態(tài)信息特征提取完成后再添加一個嵌入層,獲取各自模態(tài)的高維向量表示,用來計算多模態(tài)分類任務中單一模態(tài)信息的置信度.在多模態(tài)分類網絡的特征融合階段,將各模態(tài)信息進行連接,并將連接后的所有信息再次嵌入,對多模態(tài)信息融合信息進行再次學習,其嵌入向量表示可以用來計算特征融合后的置信度.
圖4 MNIST分類網絡Fig.4 MNIST classification network
在本節(jié)中,將通過3個實驗任務來評估所提置信度分數.3個任務分別為:單模態(tài)分類任務MNIST數據分類、單模態(tài)分類任務CIFAR-10數據分類、多模態(tài)分類任務肺部腺癌數據分類.上述提到的需進行比較的3種置信度分數分別為:① 外部輸出得到的基于最大距離的置信度分數;② 外部輸出得到的基于熵的置信度分數;③ 所提出的通過內部嵌入得到的基于距離的置信度分數.
(1) MNIST數據分類.手寫數字數據集,該數據包含6×104個訓練集示例,1×104個測試集示例,是美國國家標準與技術研究所(NIST)數據集合的子集.
(2) CIFAR-10數據分類.由10個類的6×104張32像素×32像素的彩色圖像組成,每個類包含 6×103張圖像,有5×104張訓練圖像和1×104張測試圖像.
(3) 肺部腺癌數據分類.來自一家三甲醫(yī)院采集的肺腺癌數據,包含 1 675 個樣本,其中532例浸潤性肺腺癌和 1 143 例非浸潤性腺癌.每個樣本數據有3種模態(tài)數據:高分辨計算機斷層掃描(HRCT)圖像數據、患者的結構化臨床基本信息和血液檢查信息.
3.2.1MNIST單模態(tài)分類 該任務中,使用了一個由6層卷積層和2層全連接層構成的網絡進行訓練,如圖4所示.其中:每個卷積層的卷積核參數用符號表示,如32@5×5表示32個5×5的卷積核.第1層全連接提取樣本向量表示用于估計概率密度,進而計算所提出的DCS.第2層全連接輸出用于計算MMCS和ECS.該實驗分別使用了交叉熵損失與中心損失來進行優(yōu)化比較.
3.2.2CIFAR-10單模態(tài)分類 對于CIFAR-10分類任務,使用了常規(guī)的ResNet50模型的特征提取器和2層全連接層構成的網絡進行訓練,如圖5所示.其中:z為每個殘差模塊的輸入;sg(g=1, 2, 3, 4)為殘差模塊;RelU為激活函數.模型首先提取圖像特征,然后經過2層全連接,第1層將ResNet50模型提取特征進行嵌入,用來獲取樣本的向量表示,進而計算所提出的DCS.第2層全連接輸出用于計算MMCS和ECS.該實驗中,同樣使用了交叉熵損失和中心損失來進行優(yōu)化比較.
圖5 CIFAR-10分類網絡Fig.5 CIFAR-10 classification network
圖6 基于注意力機制的圖像特征提取Fig.6 Image feature extraction based on attention mechanism
3.2.3肺部腺癌多模態(tài)分類 肺部腺癌多模態(tài)數據包含1組圖像數據和2組結構化文本數據.對多模態(tài)數據進行分類的網絡由兩部分組成:不同模態(tài)信息特征提取和多模態(tài)特征融合決策.
在特征提取部分,針對圖像數據使用了添加注意力機制的ResNet50網絡結構.在ResNet50網絡結構的基本殘差模塊中,添加通道注意力和空間注意力兩種注意力模塊,用以提高圖像重要部位的特征提取能力,如圖6所示.其中:C為卷積核的通道數量;G為卷積核深度;H、W分別為卷積核的高和寬;ωch為通道注意力輸出權重;ωsp為空間注意力輸出權重;Sigmod為轉換函數.對于另外兩組結構化文本數據,使用了多層感知機提取特征,同時使用了自注意力模塊來提高重要信息的提取能力,如圖7所示.其中:ωse為自注意力輸出權重;tanh為激活函數.
圖7 基于注意力機制的結構化文本特征提取Fig.7 Structured text feature extraction based on attention mechanism
多模態(tài)特征融合決策部分如圖8所示.首先, 將不同模態(tài)提取來的特征進行第1次嵌入,該嵌入空間的特征可以用來計算所提出的DCS,該分數可以反應不同模態(tài)信息的置信度.然后,將這些高維向量進行拼接并進行第2次嵌入,第2次嵌入空間的高維特征向量可以用來計算融合特征的DCS.最后,通過一層全連接進行輸出,輸出的向量用來計算MMCS和ECS.在該實驗中,使用中心損失來對模型進行優(yōu)化.
圖8 多模態(tài)特征融合Fig.8 Multimodal feature fusion
Brier分數(BS)是一種用來評估模型預測概率準確性的指標,是一種成本函數[30].Brier分數越低,其預測概率越準確,模型不確定性越低,置信度更高;反之,則置信度更低.Brier分數的取值范圍為0~1.二分類Brier 分數的計算公式如下:
(4)
式中:Pi為預測概率;oi為二分類預測輸出值,oi∈{0, 1}.對于多分類Brier 分數BSmut,其計算公式如下:
(5)
(6)
q=0, 1, …,Q-1
式中:Pij為多分類預測概率;oij為多分類預測輸出值;Q為預測輸出值可能的數量, 如10分類,則Q=10.
實驗對每個模型進行30次訓練迭代,計算每個訓練迭代次數中的3種置信度分數:由外部輸出計算的MMCS、ECS和由內部嵌入計算的DCS,觀察其變化規(guī)律.在3個任務中選擇訓練出來的最佳模型,比較所獲得模型的性能指標:準確率、接受者操作特征曲線下面積(AUC)、Brier分數.
由熵的定義可以知道,熵是預測結果不確定性的度量,不是預測每種可能性的度量分數,無法計算其Brier分數.基于熵的分數只用來觀察其變化規(guī)律,不計算Brier分數,所以在實驗中約定將外部輸出得到的MMCS作為外部Brier分數(BSo),將內部嵌入得到的DCS作為內部Brier分數(BSI).
圖9 MNIST數據集上的模型準確率和AUC隨E的變化曲線Fig.9 Model accuracy and AUC versus E on MNIST dataset
3.5.1MNIST 模型訓練中,準確率A、AUC隨訓練迭代次數(E)的變化規(guī)律,如圖9所示.由圖9可知,隨著訓練次數的增加,模型準確率和AUC逐步提高,最后趨于穩(wěn)定,使用中心損失優(yōu)化可以得到更高的準確率和AUC.3種置信度分數MMCS、ECS和DCS隨E的變化曲線,如圖10所示,其中:δ為置信度分數.由圖10可知,隨著E的增加,從輸出端得到的MMCS和從內部得到的DCS都是逐漸增大后趨于穩(wěn)定的,兩者最后趨于相同,而ECS則是逐漸減小后趨于穩(wěn)定的(見圖10(a)).通過變化曲線的一階差分可以知道,DCS和ECS正相關(見圖10(b)),DCS與ECS負相關(見圖10(c)).3種置信度分數間的相關系數如表1所示,其中:R為線性相關系數.
圖10 MNIST數據集上3種置信度分數隨E的變化曲線Fig.10 Three kinds of confidence scores versus E on MNIST dataset
表1 MNIST數據集上3種置信度分數間的相關系數
當訓練穩(wěn)定后,使用不同損失函數得到的最佳模型結果如表2所示.使用中心損失優(yōu)化可以得到準確率和AUC,并且通過內部計算嵌入得到的Brier分數更低,反映出通過內部參數計算出來的置信度分數更加接近真實情況.
表2 MNIST數據集上由不同損失函數訓練獲得的模型性能
3.5.2CIFAR-10 模型訓練中每個E的準確率、AUC隨E的變化規(guī)律,如圖11所示.由圖11可知,隨著E的增加, 模型準確率和AUC逐步提高,最后趨于穩(wěn)定,使用中心損失優(yōu)化可以得到更高的準確率和AUC.3種置信度分數隨E的變化曲線如圖12所示.由圖12可知,隨著E的增加,從輸出端得到的MMCS和從內部得到的DCS都是逐漸增大最后趨于穩(wěn)定, 最后兩者趨于相同, 而ECS則是逐漸減小后趨于穩(wěn)定的(見圖12(a)).通過變化曲線的一階差分可以知道,DCS和ECS正相關(見圖12(b)),DCS與ECS負相關(見圖12(c)).3種置信度分數間的相關系數如表3所示.
圖11 CIFAR-10數據集上的模型準確率和AUC隨E的變化曲線Fig.11 Model accuracy and AUC versus E on CIFAR-10 dataset
表3 CIFAR-10數據集上3種置信度分數間的相關系數
圖12 CIFAR-10數據集上3種置信度分數隨E的變化曲線Fig.12 Three kinds of confidence scores versus E on CIFAR-10 dataset
當訓練過程穩(wěn)定后,使用不同損失函數得到的最佳模型結果如表4所示.與MNIST類似,使用中心損失優(yōu)化可以得到準確率和AUC,并且通過內部計算嵌入得到的Brier分數更低,反映出通過內部參數計算出來的置信度分數更加接近真實情況.
表4 CIFAR-10數據集上由不同損失函數訓練得到的模型性能
3.5.3肺部腺癌 對于肺部腺癌多模態(tài)數據分類任務,不再對優(yōu)化器方面進行比較,該任務全部都使用中心損失優(yōu)化以獲得更好的嵌入表示.訓練中模型的準確率和AUC,如圖13所示.由圖13可知,隨著E的增加,模型的準確率、AUC逐步提高,最后趨于穩(wěn)定.當多模態(tài)數據加入后,相比于原來的單一模態(tài)圖像數據,模型性能得到了提高.通過由輸出端得到的MMCS、ECS和由內部嵌入得到的DCS隨E的變化如圖14所示.通過變化曲線的一階差分可以知道,DCS和ECS正相關(見圖14(b)),DCS與ECS負相關(見圖14(c)).3種置信度分數間的相關系數如表5所示.
當訓練穩(wěn)定后,使用不同損失函數得到的最佳模型表現如表6所示.由表6可以看到,多模態(tài)數據可以增加模型分類的準確率、AUC,并且通過內部計算嵌入得到的Brier分數更低,反映出通過內部參數計算出來的置信度分數更加接近真實情況.
圖13 肺部腺癌數據集上的模型準確率和AUC隨E的變化曲線Fig.13 Model accuracy and AUC versus E on adenocarcinoma dataset
圖14 肺部腺癌數據集上3種置信度分數隨E的變化曲線Fig.14 Three kinds of confidence scores versus E on adenocarcinoma dataset
表5 肺部腺癌數據集上3種置信度分數間的相關系數
表6 肺部腺癌數據集上的多模態(tài)分類模型性能
表7 基于距離置信度分數的多模態(tài)數據Tab.7 Multimodal data based on distance confidence score
3.5.4結果分析 通過上述3組不同的實驗數據可以知道,使用中心損失可以在獲得更好的嵌入表示的同時提高模型的性能(準確率、AUC和置信度).另外,所提通過嵌入得到的基于距離的置信度分數與輸出得到的基于最大距離的置信度分數和基于熵的置信度分數一樣可以作為一種度量模型的置信度方法,且所提方法更能真實地反應概率預測情況.此外,相比兩種由外部參數計算得到的置信度分數而言,在處理多模態(tài)數據分類時,所提出的基于距離的置信度分數不僅可以獲得模型整體的置信度,還可以獲得多模態(tài)數據基于自身信息在判斷時的置信度,并可以量化不同模態(tài)信息的重要程度.
本文提出一種在嵌入空間基于距離的置信度分數計算方法來度量模型的置信度.該方法在處理單一模態(tài)分類任務時,與其他通過模型輸出端計算置信度分數方法相似,可以作為一種度量模型置信度的手段.在處理多模態(tài)融合分類任務時,不僅可以用來度量模型整體的置信度,還可以用來評估和量化多模態(tài)數據對于模型最后判斷時的置信度影響,知道各種模態(tài)數據對于決策的重要程度.這一點在實際應用中對模型可靠性和可解釋性都有要求的場合中具有重要意義.