欒迪,周廣證
(南京理工大學紫金學院 計算機學院,江蘇 南京 210046)
手寫數(shù)字識別目前得到了廣泛的研究,例如,如果能對學生的日常作業(yè)及試卷做出高質(zhì)量的自動識別,做到線上自動批閱或者判分,將大大提高教師的工作效率和質(zhì)量。本文將嘗試使用LSTM網(wǎng)絡(luò)(long-short term memory network,長短時記憶網(wǎng)絡(luò))結(jié)合注意力機制對Minist數(shù)據(jù)集進行識別。Minist數(shù)據(jù)集是一個手寫數(shù)字數(shù)據(jù)庫,它有60000個訓練樣本集和10000個測試樣本集,是NIST數(shù)據(jù)庫的一個子集。每個樣本都是一張28×28像素的灰度手寫數(shù)字圖片,且每個樣本都對應(yīng)著一個唯一的標簽[1-2]。
LSTM是當前最有效的基于長時記憶的神經(jīng)網(wǎng)絡(luò)識別算法。它是對RNN(Recurrent Neural Network,循環(huán)神經(jīng)網(wǎng)絡(luò))的改進,LSTM和RNN一般用來處理序列信息,在文本、語音、視頻等具有上下文關(guān)聯(lián)的識別和預測場景中識別精度很高。本文將手寫數(shù)字的圖像看作以行為單位的數(shù)據(jù),對于特定的數(shù)字,各行之間的信息顯然具有強相關(guān)的聯(lián)系。LSTM在接收當前行信息時,將之前的所有行信息都傳遞過來進行識別輸出,有效利用了上下文信息[3-4]。但任一行對當前行的影響概率卻沒有明顯差別,這是不合理的。當前行和與當前行聯(lián)系緊密的行信息顯然應(yīng)該具有更大的權(quán)值,注意力機制通過按信息關(guān)聯(lián)的強度分配不同權(quán)重的方法,可以解決這個問題。
綜上所述,本文設(shè)計了基于LSTM和注意力機制的Minist手寫數(shù)字識別算法。手寫數(shù)字信息保存為28×28的矩陣,每張圖片按行輸入LSTM網(wǎng)絡(luò),通過注意力機制調(diào)節(jié)權(quán)值來確定輸入的所有行信息對當前輸出的影響概率。
RNN是循環(huán)神經(jīng)網(wǎng)絡(luò),簡單的RNN結(jié)構(gòu)如圖1所示,包含一個輸入層、一個隱藏層、一個輸出層。權(quán)重矩陣W就是隱藏層上一次的值作為這一次的輸入的權(quán)重。
將圖1的循環(huán)層按時間步展開的結(jié)構(gòu)如圖2所示。圖中,Xt為當前時刻的輸入,Ot為當前時刻的輸出,St為隱藏層的當前值。
圖2 RNN權(quán)值
RNN在任意時刻的神經(jīng)元結(jié)構(gòu)都是相同的。不僅如此,其在不同時刻傳遞時的對應(yīng)位置的權(quán)值也是共享的,圖中不同時刻的權(quán)值W、U、V采用的都是同一矩陣,其意義也是顯而易見的,即在前面信息中學習到的特征可以移植給后面的網(wǎng)絡(luò)直接使用。公式如下:
RNN在反向傳播時面臨著梯度消失和梯度爆炸的問題,而且對于相當長度的前文信息來說,其有效性大大降低。LSTM解決了這幾個問題,其結(jié)構(gòu)如圖3所示。LSTM由遺忘門、輸入門和輸出門三個控制門組成。遺忘門控制上一時刻的單元狀態(tài)Ct-1有多少保留到當前狀態(tài)Ct,輸入門控制當前時刻的網(wǎng)絡(luò)輸入Xt有多少保存到單元狀態(tài)Ct,輸出門控制單元狀態(tài)Ct有多少輸出到LSTM網(wǎng)絡(luò)的當前輸出ht。圖中σ表示sigmoid函數(shù),其取值范圍是[0-1],決定了門控制器能夠通過信息的比例。sigmoid取值為1時,表示所有信息都能通過,完全保留這一分支的記憶,取值為0時,表示沒有信息能夠通過,即所有信息全部遺忘[5-6]。LSTM網(wǎng)絡(luò)的主要計算公式如下:
圖3 LSTM結(jié)構(gòu)
人類的注意力機制能夠利用有限的視覺信息處理資源,從大量信息中獲取有價值的信息,極大地提高了視覺處理的效率。深度學習中的注意力機制受人類視覺注意力啟發(fā),能夠從眾多信息中抽選出對當前任務(wù)目標更為關(guān)鍵的信息。在Bahdanau等首次在機器翻譯中引入注意力機制,并取得不錯的效果之后,其在CNN(Convolutional Neural Network,卷積神經(jīng)網(wǎng)絡(luò))抽取圖像特征、RNN抽取序列信息特征等任務(wù)中都有廣泛的應(yīng)用[7-9]。
在深度學習中,注意力機制可以借助重要性權(quán)重向量來實現(xiàn)。在預測或推斷目標值時,例如文本翻譯中詞與詞之間的聯(lián)系,可以用注意力向量來判斷當前輸出詞與其他詞的關(guān)聯(lián)強度,然后對加權(quán)后的向量求和以逼近正確的標簽值。簡單來說,注意力機制就是分配權(quán)重,例如英文句子“She is wearing a red dress.”中,單詞“wearing”和“dress”屬于強相關(guān)關(guān)系,“is”和“dress”屬于弱相關(guān)關(guān)系,注意力機制在預測“dress”時,就會給“wearing”賦予較高權(quán)重,給“is”賦予較低權(quán)重。
本實驗的算法設(shè)計和實驗流程如圖4所示。首先下載Minist數(shù)據(jù)集,將輸入數(shù)據(jù)X保存為28×28的矩陣并做歸一化處理,標簽數(shù)據(jù)Y轉(zhuǎn)化為獨熱編碼表示。然后通過Keras搭建LSTM網(wǎng)絡(luò),加入注意力機制層,最后將訓練集按epoch喂入網(wǎng)絡(luò)進行參數(shù)訓練,并通過測試集測試訓練效果。
圖4 手寫數(shù)字識別流程圖
訓練集設(shè)置了10個epoch,為防止過擬合,設(shè)置了dropout率為0.25,實驗最終準確率為0.984,測試集準確率為0.9878。為了對比,將注意力機制層去掉,僅使用LSTM網(wǎng)絡(luò)進行訓練和測試,訓練集經(jīng)過10個epoch后,準確率為0.9599,20個epoch后為0.9719,測試集準確率為0.9789。對比實驗結(jié)果發(fā)現(xiàn),在損失率和準確率的表現(xiàn)上,注意力機制的作用效果都很明顯。兩次實驗結(jié)果如下:
LSTM和注意力機制都是當前研究的熱點,有廣闊的發(fā)展前景。相對于傳統(tǒng)的深度識別算法,循環(huán)神經(jīng)網(wǎng)絡(luò)能夠處理序列數(shù)據(jù)信息的上下文關(guān)系,LSTM又改進了普通RNN模型的長時依賴以及梯度消失和梯度爆炸問題。在上下文信息的依賴關(guān)系上,由注意力機制分配權(quán)重以保證最有價值的輸入數(shù)據(jù)影響最終輸出結(jié)果。實驗表明,LSTM結(jié)合注意力機制模型的識別率效果非常好。本實驗將進一步挖掘該模型的應(yīng)用領(lǐng)域,在序列信息處理時,例如文本的上下文、視頻上下幀的分析和預測等,能夠發(fā)揮LSTM和注意力機制的強大優(yōu)勢,取得滿意的應(yīng)用效果。