徐盼盼,陳長駿,閆志文,李林超
(1.浙江省人民醫(yī)院杭州醫(yī)學院附屬人民醫(yī)院,臨床醫(yī)學工程部,杭州 310014;2.浙江大學醫(yī)學院附屬第二醫(yī)院,臨床醫(yī)學工程部,杭州 310009;3.浙江啄云智能科技有限公司,杭州 310052)
目前中國已是世界上糖尿患者人數(shù)最多的國家之一,有超過11 000萬患者且發(fā)病人口仍在增長中,專家預計到2035年發(fā)病人數(shù)將達到59 200萬。不少糖尿病患者還伴有其他并發(fā)癥,其中視網(wǎng)膜疾病就是典型的并發(fā)癥之一。調研表明,患有高血糖視網(wǎng)膜并發(fā)癥(diabetic retinopathy,DR)的病人人數(shù)約占糖尿病總人數(shù)的1/3,所以糖尿病視網(wǎng)膜病已變成國際失明人數(shù)上升的主因,該病癥會嚴重影響到病人的視力甚至導致失明,是當今致盲率最高的導火索。在臨床中,因為視網(wǎng)膜圖像的影像數(shù)據(jù)等級較多且各個級別間差別較小,醫(yī)生對此種病例的分級治療過程相對耗時,長期大量的閱片將會讓專家陷入疲憊狀態(tài),導致發(fā)生錯誤治療、漏癥等狀況的出現(xiàn),從而影響閱片的準確率。基于上述原因,通過深度機器學習技術輔助醫(yī)師對DR疾病分級鑒別確診已經(jīng)成為了近些年在該領域的一個重點研發(fā)方向,而通過該技術能夠更高效地輔助醫(yī)師開展糖網(wǎng)病治療,從而大大提高了DR疾病確診的效率和準確度,并有著巨大的臨床使用價值。對DR患者進行早期篩查和治療可以有效防止視覺損害及失明,早期確診原發(fā)疾病就可以對患者進行跟蹤隨訪,能夠有效的幫助患者保存視力、阻止DR患者失明。
近些年不少國內(nèi)外學者針對深度學習在糖尿病視網(wǎng)膜病識別問題展開了一定的研究。Saichua等[1]采用了深度學習方法對糖尿病導致的視網(wǎng)膜病變圖片進行等級分級,利用深度學習神經(jīng)元去提取不同病變等級視網(wǎng)膜特征;Gulshan等[2]在公開數(shù)據(jù)集EyePACS-1和Messidor-2的基礎上,利用深度學習方法進行臨床驗證,并制定評價指標;Valarmathi等[3]對病變視網(wǎng)膜分別采用手工提取特征和深度學習方法提取特征進行對比,論證深度學習方法提取病變視網(wǎng)膜特征的可行性;Liu等[4]將殘差網(wǎng)絡和自注意網(wǎng)絡進行對比,得到殘差網(wǎng)絡比自注意網(wǎng)絡效果好;Xie等[5]提出了Resnext殘差結構,對網(wǎng)絡的寬度和深度進行復雜化,提高分類的準確率;Bello等[6]對殘差結構的網(wǎng)絡寬度和注意力機制進行了修改,證明了殘差網(wǎng)絡識別準確率高于NAS自動搜索網(wǎng)絡;Ma等[7]驗證了基于Transformer模型的視網(wǎng)膜圖像分類方法的可行性;Hu等[8]驗證通過擠壓和激勵模塊增加通道之間的相關性,從而提高網(wǎng)絡層提取特征的空間和通道的表征能力;Rosenfeld等[9]對網(wǎng)絡的寬度、深度以及尺度進行調節(jié),驗證了模型的泛化能力和數(shù)據(jù)以及網(wǎng)絡之間的相關性;鄭雯等[10]對 ResNext50聚合殘差結構進行預訓練,結合多種數(shù)據(jù)增強策略擴充數(shù)據(jù)集,提高分類的準確率;顧婷菲等[11]使用了多通道注意力[12]選擇機制的細粒度分級方法[13-14]對糖尿病性視網(wǎng)膜病變分級。陳明惠等[15]將遷移學習技術應用于視網(wǎng)膜圖像自動分類上,達到了高效的視網(wǎng)膜病變自動分類效果。
糖尿病視網(wǎng)膜病變不同等級差異非常細微,病灶點微小且對分類精度要求較高,此外,相關影像數(shù)據(jù)比較有限。上述問題給深度學習模型用于糖尿病視網(wǎng)膜病變圖像自動分類帶來了一定的挑戰(zhàn)。根據(jù)以上分析,為了進一步提高對DR的分析準確度與魯棒性,從3個方面對模型進行了設計與優(yōu)化。
(1)多級特征殘差塊:因為殘差結構中的卷積層受卷積核限制,輸出的特征層具有局限性,無法采用多分辨率[16-17]進行特征提取,提取到的特征圖損失比較嚴重,所以多級特征殘差塊通過多組卷積在感受野和輸入層的分辨率方面進行優(yōu)化。多組卷積依據(jù)不同分辨率輸入得到的多個不同尺度的特征,挖掘的特征信息更加全面和有效。多級特征殘差塊分成2個階段:①階段一從前到后級聯(lián),將前一組通道卷積得到的特征疊加到后一組通道中,得到的特征層具有前一層像素信息,再通過卷積得到的特征層具有不同分辨率信息;②階段二從后往前級聯(lián),將后一組的卷積得到特征層疊加到前一層特征中,用1×1的卷積壓縮通道,這樣得到特征層具有多層語義信息,感受野更大。
(2)全局通道聯(lián)合注意力機制:全局通道聯(lián)合注意力機由通道注意力機制模型和全局上下文模塊兩個部分組成,得到的特征層既有通道的上下文關系,又有特征層的上下文關系,使特征層具有全局感受野,從而更好地捕獲視網(wǎng)膜病變的有效特征信息。
(3)設計了集成難例挖掘的訓練方法:集成難例訓練方法引導模型關注難分類和分類錯誤的數(shù)據(jù)集,減少對易分類數(shù)據(jù)集的過多學習,其方法能夠解決傳統(tǒng)模型訓練方法對易分類數(shù)據(jù)集和難分類數(shù)據(jù)同樣權重的學習,導致模型訓練階段對難分類數(shù)據(jù)關注不夠,對易分類數(shù)據(jù)集學習過多;而在推理階段對難分類數(shù)據(jù)集特征不夠敏感,對易分類數(shù)據(jù)集特征過于敏感。集成難例挖掘訓練方法包含了3個方面:①挖掘難分類訓練集,在模型訓練過程中在線分析訓練集,挖掘難例信息;②采用圖像融合方法將難例數(shù)據(jù)和原始訓練集進行融合,得到多類訓練集;③多類別損失函數(shù)計算:采用sigmoid方法實現(xiàn)一張圖片多個類別分類問題。現(xiàn)對主干網(wǎng)絡結構的殘差塊采用不同分辨率進行特征提取和不同感受野融合,殘差塊間采用全局和通道注意力來提高有效特征提取,以及在線難例樣本挖掘提高易錯樣本的模型學習能力。因此,研究的MA-DRNet模型可提升在不同分辨率下病灶特征提取、增大感受野、降低有效特征的損失和增加難例樣本的學習能力,從而提高模型對眼膜病變等級識別的準確率。
在多級特征殘差塊、全局通道聯(lián)合注意力機制和集成難例訓練方法3個方面進行創(chuàng)新。
(1)多級特征殘差塊是基于Resnet50基礎框架進行修改。Resnet50的殘差塊雖然能夠將輸入層特征信息和輸出層特征信息進行融合,但受3×3卷積的限制,得到特征圖具有局限性、輸出特征層對輸入層的信息沒有更深入的挖掘,挖掘的語義信息較少,輸出的卷積圖有效信息丟失嚴重,但多級特征殘差塊能夠有效解決Resnet50的殘差塊的不足。多級特征殘差塊分成兩個部分:①3×3卷積替換成多組卷積使模型能夠對同一張?zhí)卣鲌D片進行不同分辨率的有效特征信息挖掘;②每組卷積與鄰近卷積的特征層融合,擴大模型的感受野,減緩卷積所帶來的局部性。
(2)全局通道聯(lián)合注意力機制:用于殘差模型之間的連接,得到的特征層具有全局性,增加模型的感受野。全局通道聯(lián)合注意力機制在通道和特征層兩個方面對有效特征增加權重,利用損失值對模型反向傳播,引導模型向有效信息學習,減少模型對噪聲等無效信息的關注。
(3)集成難例訓練方法:在模型訓練過程中,對訓練集進行分析,得到難例樣本;采用圖像融合的方法,將難例數(shù)據(jù)集和原始數(shù)據(jù)集進行融合,提高難例數(shù)據(jù)集出現(xiàn)的頻率,從而提高模型對難例數(shù)據(jù)集的學習能力;采用多類損失函數(shù)方法,解決傳統(tǒng)分類模型一張圖片只能出現(xiàn)一類的難題,實現(xiàn)了圖像融合。不僅如此,多類別損失計算減少易分類訓練集對模型的影響,增加模型對難例樣本的學習。圖1給出了本文對糖尿病視網(wǎng)膜病變等級識別的整體方案和分類模型框架。包括對數(shù)據(jù)集進行加載,模型讀取圖片信息;然后對數(shù)據(jù)進行預處理,采用最小比例縮放方法對數(shù)據(jù)進行縮放;接著將縮放信息輸送到模型中;最后對網(wǎng)絡輸出層采用sigmoid操作得到預測值。
圖1 糖尿病視網(wǎng)膜病變等級識別的整體方案圖Fig.1 Overall plan of grade recognition of diabetes retinopathy
模型的基礎架構是Resnet50網(wǎng)絡,先對輸入塊做卷積操作,然后包含4個殘余誤差塊,最后再進行全連接操作以便于完成分析任務,Resnet50包含50個conv2d操作。其中,殘差單元的算法如圖2所示,通過殘差學習解決了深度網(wǎng)絡的梯度彌散問題。但這樣的殘差結構受到神經(jīng)網(wǎng)絡卷積核的限制,得到的特征值是從固定范圍提取,僅能代表卷積層部分信息?;谏鲜鰡栴}進行修改,將在通道方面進行分組卷積,如圖3所示。
圖2 原始殘差單元Fig.2 Original residual element
圖3 多級特征殘差結構Fig.3 Multistage characteristic residual structure
(1)先依次對卷積通道進行卷積,將上一組卷積輸出的特征層疊加到下一組卷積特征層。得到當前組卷積的輸入信息擁有本層輸入層信息也有上一組經(jīng)過分組卷積得到的特征信息,這樣形成的分組卷積的輸入層具有不同分辨率信息,可以讓模型對不同分辨率進行信息挖掘,相當于采用不同卷積核對輸出層進行數(shù)據(jù)挖掘,得到不同感受野特征層。
(2)反方向疊加,從后向前操作,將分組卷積向前依次疊加,大感受野特征值重新疊加到前一組分組卷積特征層;因此,多尺度[18]特征殘差塊增大網(wǎng)絡感受野,使每組分組卷積除了原始卷積3×3的感受野以外還增加后一層感受野信息。除此之外,與其他分組卷積特征值相加,得到其他通道信息,捕獲不同通道糖尿病視網(wǎng)膜信息。因此,多尺度特征可以增加捕獲細節(jié)和獲得不同感受野信息,擴大各個網(wǎng)絡層的感受野覆蓋范圍。
殘差結構輸出xl+1由上一個階段特征層xl直接映射和本階段的特征輸出層兩部分組成。
殘差結構表達式為
xl+1=xl+F(xl,wl)
(1)
式(1)中:xl表示上一個階段特征層;F(xl,wl) 表示卷積操作;xl表示當前階段特征層;wl表示卷積核。
多尺度特征殘差結構表達式為
y0=x0
(2)
y11=f1(x1,w11)
(3)
y12=f1(x2+y11,w12)
(4)
y13=f1(x3+y12,w13)
(5)
y22=f2[f3(y12+y13),w22]
(6)
y21=f2[f3(y11+y22),w21]
(7)
yl=f3(y0,y21,y22,y13)
(8)
式中:x0、x1、x2、x3表示將輸入層按通道分成4組;f1表示3×3卷積;f2表示1×1卷積;f3表示特征層按通道疊加;yl表示特征層輸出;w11、w12、w13表示3×3卷積核,w22、w21表示1×1卷積核。
從式(2)~式(8)可以得到輸出的特征層包含了每組通道的特征信息,其中第2、3、4組先依次為后一組分組卷積增加感受野和特征值,然后反向傳輸,將后一組特征按通道疊加到上一組特征中,采用1×1卷積對通道進行減少到原來一樣,這樣第2、3、4在不減少特征損失的前提下,擴大特征提取范圍,使殘差結構更多關注有效特征值。
由于視網(wǎng)膜圖像的復雜性,對于病灶的識別僅僅用到病灶所在區(qū)域的局部特征[19-20]是不充分的,往往還需要依賴周圍區(qū)域的特征甚至全局區(qū)域[21]的整體特征,而卷積神經(jīng)網(wǎng)絡是通過滑窗的方式分別提取局部的信息,難以建立信息之間的依賴關系。此外,不同特征通道之間表達的信息側重點不同,對于病灶的識別往往更注重于邊緣和紋理特征,這就需要模型提高對相應特征通道的關注[22]。因此,本文提出了一種全局通道聯(lián)合注意力模塊,使模型可以同時具備捕獲長距離依賴關系和通道注意力的能力。全局通道聯(lián)合注意力模塊包含通道注意力模塊與全局上下文模塊,具體實現(xiàn)如圖4所示,圖中的C、H、W分別為特征圖的通道數(shù)、高、寬,操作原理定義為
圖4 全局通道聯(lián)合注意力模塊Fig.4 Global channel joint attention module
Zi=XiS(Bs+Bc)
(9)
Bs=R[wnG(Xi)]
(10)
(11)
先使用全局平均池化將特征圖中每個通道的特征固定成相同尺度于上下文建模,然后使用1×1卷積計算每個通道的重要程度。
模型訓練結束后使用難例挖掘方法對難分樣本進一步鞏固訓練,加強模型對于易錯樣本的學習,本文提出一種綜合多階段的難例挖掘與訓練方法,具體方法如下。
(1)選擇最后k個epoch模型結果,k一般取總的訓練epoch的1/10。
(2)使用這k個模型分別在訓練數(shù)據(jù)上前傳傳播得到預測結果。
(3)根據(jù)便簽樣本統(tǒng)計k個模型在訓練集中分別預測錯誤的圖像列表{L1},{L2},…,{Lk}。
(4)統(tǒng)計k個列表中圖像列表的交集,重復采樣2次得到列表Ln。
(5)統(tǒng)計k個列表中出現(xiàn)次數(shù)大于2次的圖像,采樣1次得到列表Lu。
(6)Lk、Ln、Lu這3個列表混合的一起,得到最終的混合難例列表Ls。
(7)修改采樣器,隨機抽取Ls中的數(shù)據(jù)與訓練集采用圖像融合技術進行數(shù)據(jù)預處理得到mix_data。
(8)將mix_data輸入的網(wǎng)絡得到輸出特征層。
(9)輸出特征層采用sigmoid預測圖片的類別和得分。
(10)將預測結果與標簽進行損失計算,反向傳播,優(yōu)化模型。
4.2 圖像融合技術
圖像融合技術把不同類型的圖像按比例融合,達到擴充訓練數(shù)據(jù)集,如圖5所示。數(shù)據(jù)混合計算方法為
圖5 圖像融合前后對比圖Fig.5 Image fusion before and after comparison
λ=B(a,b)
(12)
MBx=λBx1+(1-λ)Bx2
(13)
MBy=λBy1+(1-λ)By2
(14)
式中:B為貝塔分布;λ為混合系數(shù);α和β設置為0.5;MBx為混合后的樣本數(shù)據(jù);MBy為數(shù)據(jù)集對應的標簽;Bx1、Bx2分別對應兩種不同數(shù)據(jù)集。
模型的輸出采用one-hot的編碼方式,為了使模型擁有更豐富的特征信息,擴大類間距,采用多類別分類頭,訓練時將同一個batch內(nèi)的圖像采用圖像融合的方式融合到一起,那么同一張圖像便擁有了多個圖像的特征信息,訓練階段編碼方式如圖6所示,分類頭輸出的神經(jīng)元個數(shù)與分類任務的類別數(shù)相同,對特征圖采用sigmoid函數(shù)進行損失計算,也就是說對每張圖每個類別進行預測,這樣就可以解決一張圖片有多個類別的問題。采用分類多類別數(shù)據(jù)集訓練可以讓模型更好的關注到糖尿病病變不同等級的區(qū)別,使模型捕獲更細粒度等級的特征。
圖6 多類別分類頭Fig.6 Multi-category classification header
首先,圖像送入模型中進行推理,輸出層的激活函數(shù)使用Sigmoid函數(shù)得到每個類別的置信度,函數(shù)表達式為
(15)
(16)
最后,通過對模型進行反向推導得到模型參數(shù)的偏差,采用Adam算法對模型參數(shù)優(yōu)化,將梯度動量并入梯度指數(shù)加權估計,使用偏置修正非中心的二階矩估計和動量項一階矩。
本文所使用的是DR_data數(shù)據(jù)集采用Kaggle(2014)和MESSIDON(French Ministry,2014)公開數(shù)據(jù)集。其中Kaggle數(shù)據(jù)集是由法國EyePacs眼底及視網(wǎng)膜平臺免費提出,該網(wǎng)絡平臺匯集了多個醫(yī)院的視網(wǎng)膜圖像,共計數(shù)萬張圖像涵蓋了不同的成像環(huán)境。而MESSIDON眼膜數(shù)據(jù)則是由法國國防部研究部的篩查項目所提出,收集了3家不同眼科機構數(shù)據(jù)。視網(wǎng)膜類別圖像如圖7所示。
圖7 數(shù)據(jù)類別圖Fig.7 Data category
本文實驗代碼基于python3.7,pytorch版本為1.7,torchvision版本0.8.0,Linux系統(tǒng)版本為20.04,GPU配置為Rtx3090。在網(wǎng)絡訓練過程中,Batch size是32,學習率是0.1,優(yōu)化器是adam,學習率衰減是0.000 1,線程為8。
選用特異性SP(specificity)、敏感性SE(sensiti-vity)、準確性AC(accuracy)作為指標評估。其計算方法分別為
(17)
(18)
(19)
式中:TP、FP、FN、TN分別代表真陽性、假陽性、假陰性和真陰性。
以Resnet50為基線模型分別對多級別殘差結構、全局通道聯(lián)合注意力模塊和難例挖掘所做的優(yōu)化進行消融實驗,實驗在Kaggle(2014)和MESSIDON兩個數(shù)據(jù)集上的平均準確率如表1所示。
表1 不同基線模型平均準確率對比結果Table 1 Comparison results of average accuracy of different baseline models
本文最終優(yōu)化后的模型在各個類別上的預測結果準確率如表2所示。
表2 本文模型類別結果準確率Table 2 This paper model classification results accuracy table
本文中采用sigmoid得到分類結果,依據(jù)式(17)~式(19)對DR數(shù)據(jù)集的五分類和二分類各項評估指標進行計算。最終本文模型在測試集上的特異性為99.02%,敏感性為98.26%,準確率為98.87%,優(yōu)于現(xiàn)有各方法,如表3所示。
表3 不同算法對比實驗結果Table 3 Compare the experimental results with different algorithms
提出了MA-DRNet模型解決的了傳統(tǒng)卷積神經(jīng)網(wǎng)絡對于糖尿病視網(wǎng)膜病變復雜特征學習困難,準確率低的問題。提出的多級特征殘差塊擴充了模型的感受野,加強了模型對于小尺度病灶的學習能力以及對于尺度的魯棒性;優(yōu)化的全局通道聯(lián)合注意力機制同時實現(xiàn)像素長距離依賴關系捕獲,提升了模型對于復雜病灶的表征效果;設計的集成難例挖掘與訓練方法,改善了模型對于易錯樣本學習效果差的問題。使用本文MA-DRNet模型在Kaggle和MESSIDON兩個數(shù)據(jù)集上訓練和測試,在測試集上分類準確率達到98.87%,特異性達到99.02%,敏感性達到98.26%,超過目前已知同類方法。此外,本模型所提出的方法可以即插即用到其他卷積神經(jīng)網(wǎng)絡中,可以大幅提升模型對于糖尿病視網(wǎng)膜病變分級的準確率。本文方法對于糖尿病視網(wǎng)膜病變的自動分級診斷,提升眼科疾病篩查的準確率和效率方面具有重要意義。