摘 要:離散序列生成廣泛應(yīng)用于文本生成、序列推薦等領(lǐng)域。目前的研究工作主要集中在提高序列生成的準(zhǔn)確性,卻忽略了生成的多樣性。針對(duì)該現(xiàn)象,提出了一種自適應(yīng)序列生成方法ECoT,設(shè)置兩層元控制器,在數(shù)據(jù)層面,使用元控制器實(shí)現(xiàn)自適應(yīng)可學(xué)習(xí)采樣,自動(dòng)平衡真實(shí)數(shù)據(jù)與生成數(shù)據(jù)分布得到混合數(shù)據(jù)分布;在模型層面,添加多樣性約束項(xiàng),并使用元控制器自適應(yīng)學(xué)習(xí)最優(yōu)更新梯度,提升生成模型生成多樣性。此外,進(jìn)一步提出融合協(xié)同訓(xùn)練和對(duì)抗學(xué)習(xí)的方法,提升生成模型生成準(zhǔn)確性。與目前的主流模型進(jìn)行對(duì)比實(shí)驗(yàn),結(jié)果表明,在生成準(zhǔn)確性和多樣性上,自適應(yīng)協(xié)同訓(xùn)練序列生成方法具有更均衡的準(zhǔn)確性和多樣性,同時(shí)有效緩解了生成模型的模式崩潰問題。
關(guān)鍵詞:深度學(xué)習(xí);機(jī)器學(xué)習(xí);序列生成;協(xié)同訓(xùn)練;對(duì)抗學(xué)習(xí)
中圖分類號(hào):TP391 文獻(xiàn)標(biāo)志碼:A
文章編號(hào):1001-3695(2022)07-025-2081-06
doi:10.19734/j.issn.1001-3695.2021.12.0681
基金項(xiàng)目:國(guó)家社會(huì)科學(xué)基金重大資助項(xiàng)目(13amp;ZD091,18ZDA200);河北省重點(diǎn)研發(fā)計(jì)劃項(xiàng)目(20370301D);河北師范大學(xué)重大關(guān)鍵技術(shù)攻關(guān)項(xiàng)目(L2020K01)
作者簡(jiǎn)介:張寶奇(1996-),男,河北承德人,碩士研究生,主要研究方向?yàn)闄C(jī)器學(xué)習(xí)、智能信息處理;趙書良(1967-),男(通信作者),河北滄州人,教授,博導(dǎo),主要研究方向?yàn)闄C(jī)器學(xué)習(xí)、智能信息處理(zhaoshuliang@sina.com);張劍(1991-),男,河北石家莊人,碩士研究生,主要研究方向?yàn)闄C(jī)器學(xué)習(xí)、智能信息處理;呂曉鋒(1996-),男,河北石家莊人,碩士研究生,主要研究方向?yàn)闄C(jī)器學(xué)習(xí)、智能信息處理.
Sequence generation method based on adaptive learning
Zhang Baoqia,b,c,Zhao Shulianga,b,c?,Zhang Jiana,b,c,Lyu Xiaofenga,b,c
(a.College of Computer amp; Cyber Security,b.Hebei Provincial Engineering Research Center for Supply Chain Big Data Analytics amp; Data Security,c.Hebei Provincial Key Laboratory of Network amp; Information Security,Hebei Normal University,Shijiazhuang 050024,China)
Abstract:Discrete sequence generation is widely used in text generation,sequence recommendation and other fields.The current research work mainly focuses on improving the accuracy of sequence generation,but ignores the diversity of generation.To address this phenomenon,this paper proposed an adaptive sequence generation method (ECoT),and designed a two-layer meta controller.In the data layer,the function of meta controller was to realize adaptive learning sampling,automatically balance the distribution of real data and generated data,and obtain mixed data distribution.At the model level,this paper added diversity constraints.The function of the meta controller was to adaptively learn the optimal update gradient to improve the generation diversity of the generation model.In addition,in order to improve the accuracy of the generation model,this paper proposed a method combining cooperative training and adversarial learning.Compared with the current mainstream models,the results show that the adaptive cooperative training sequence generation method has more balanced accuracy and diversity in terms of generation accuracy and diversity,and can effectively alleviate the pattern collapse of the generation model.
Key words:deep learning;machine learning;sequence generation;cooperative training;adversarial learning
0 引言
序列生成模型廣泛應(yīng)用于自然語言生成(natural language generation,NLG)任務(wù)[1~6]、推薦系統(tǒng)(recommendation system,RS)[7~10]等諸多領(lǐng)域。基于極大似然估計(jì)(maximum likelihood estimation,MLE)的神經(jīng)網(wǎng)絡(luò)模型是序列生成的基本方法,繼神經(jīng)網(wǎng)絡(luò)反向傳播算法出現(xiàn)之后,基于極大似然估計(jì)的前饋神經(jīng)網(wǎng)絡(luò)和循環(huán)神經(jīng)網(wǎng)絡(luò)的序列生成,雖然可以得到與訓(xùn)練集樣本相似的序列數(shù)據(jù),但基于極大似然估計(jì)方法生成的樣本在質(zhì)量上良莠不齊。首先,在模型訓(xùn)練階段,生成序列每一步所需的序列項(xiàng)均是來自于訓(xùn)練集,測(cè)試階段輸入序列項(xiàng)均是由生成模型生成而來,使得訓(xùn)練階段未曾發(fā)現(xiàn)的錯(cuò)誤學(xué)習(xí)結(jié)果在測(cè)試階段快速積累而導(dǎo)致曝光偏差 (exposure bias)問題[11]。其次,序列生成過程中使用的數(shù)據(jù)均為離散數(shù)據(jù),且每次只生成一個(gè)序列項(xiàng),訓(xùn)練過程選用交叉熵函數(shù)作為模型損失函數(shù),BLEU等離散型評(píng)價(jià)指標(biāo)對(duì)模型進(jìn)行評(píng)價(jià),結(jié)果表明,各種離散型指標(biāo)的不可微性最終導(dǎo)致了損失函數(shù)優(yōu)化難的問題。隨后,生成對(duì)抗網(wǎng)絡(luò)(generative adversarial networks,GAN)在連續(xù)數(shù)據(jù)生成領(lǐng)域展現(xiàn)了強(qiáng)大的性能,并擁有可靠的理論基礎(chǔ),逐漸成為流行的生成模型之一[2]。該模型基于零和博弈的思想,生成器生成偽數(shù)據(jù)去混淆判別器,使得判別器無法判斷輸入數(shù)據(jù)是否來自于訓(xùn)練集;判別器為了能夠正確判別數(shù)據(jù)真?zhèn)芜M(jìn)行不斷訓(xùn)練,然后判別器通過訓(xùn)練將學(xué)習(xí)到的知識(shí)通過梯度的形式傳遞給生成器,引導(dǎo)生成器進(jìn)行訓(xùn)練。但該模型在序列離散數(shù)據(jù)生成上仍具有目標(biāo)函數(shù)不可微、優(yōu)化難等問題。
現(xiàn)階段,基于上述問題提出了兩類解決方法,即強(qiáng)化學(xué)習(xí)和策略梯度方法、改進(jìn)的生成對(duì)抗網(wǎng)絡(luò)方法。在第一類強(qiáng)化學(xué)習(xí)和策略梯度方法中,最為經(jīng)典的是序列生成對(duì)抗網(wǎng)絡(luò)模型(sequence generative adversarial networks,SeqGAN)[12]。該模型使用策略梯度算法優(yōu)化目標(biāo),在很多場(chǎng)景下都有出色的表現(xiàn),但該模型在不同場(chǎng)景下生成結(jié)果具有不穩(wěn)定、不可靠的缺點(diǎn)。第二類方法中,主要使用Gumbel-softmax[13],得到一個(gè)近似連續(xù)分布的序列離散數(shù)據(jù)分布,使得模型在訓(xùn)練過程中的目標(biāo)函數(shù)是可微的,該方法極大地增強(qiáng)了訓(xùn)練過程的穩(wěn)定性[13]。上述各種生成模型算法中,其更多地注重對(duì)模型生成準(zhǔn)確度的改進(jìn),卻忽視了模型生成的多樣性。為此,提出一種協(xié)同訓(xùn)練方法,即離散數(shù)據(jù)生成的協(xié)同訓(xùn)練(cooperative training,CoT)[14]。CoT模型通過直接優(yōu)化Jensen-Shannon散度來進(jìn)行針對(duì)離散數(shù)據(jù)生成模型的訓(xùn)練,并且改進(jìn)了優(yōu)化過程,引入了中間模型,在一定程度上提升了模型生成數(shù)據(jù)的多樣性,但是仍存在優(yōu)化空間。
大多數(shù)研究者的工作主要集中在連續(xù)數(shù)據(jù)上,而對(duì)于生成序列離散數(shù)據(jù)的研究較少,并且生成式對(duì)抗網(wǎng)絡(luò)在序列生成過程中忽略了生成的多樣性,為此本文提出了自適應(yīng)序列生成方法ECoT,使用兩層元控制器。在數(shù)據(jù)層面平衡真實(shí)數(shù)據(jù)以及生成數(shù)據(jù)分布,第一層元控制器調(diào)節(jié)真實(shí)樣本分布和生成樣本分布的混合程度,得到混合數(shù)據(jù)分布。在模型層面優(yōu)化模型訓(xùn)練的過程,第二層元控制器調(diào)節(jié)生成器更新梯度,進(jìn)而尋找樣本生成質(zhì)量和多樣性之間的最佳平衡,在保證其生成樣本準(zhǔn)確性的同時(shí)最大限度地提高生成的多樣性。
1 相關(guān)工作
1.1 符號(hào)定義
本文中所用到的符號(hào)定義如表1所示。
1.2 序列生成對(duì)抗網(wǎng)絡(luò)
2014年,Goodfellow等人[15]將零和博弈的對(duì)抗學(xué)習(xí)思想與深度學(xué)習(xí)相結(jié)合,提出了生成對(duì)抗網(wǎng)絡(luò)。在生成對(duì)抗網(wǎng)絡(luò)中,生成器Gθ接收真實(shí)樣本數(shù)據(jù)來生成序列,其本質(zhì)是對(duì)序列進(jìn)行特征提取,然后根據(jù)提取到的特征學(xué)習(xí)樣本真實(shí)分布來混淆判別器Dψ。而判別器Dψ的目標(biāo)則是能夠完全區(qū)分?jǐn)?shù)據(jù)的真?zhèn)巍.?dāng)判別器Dψ達(dá)到無法準(zhǔn)確區(qū)分真實(shí)的序列樣本和生成的序列樣本狀態(tài),便是生成對(duì)抗網(wǎng)絡(luò)的理想狀態(tài),這種狀態(tài)被稱之為納什平衡狀態(tài)[15]。生成對(duì)抗網(wǎng)絡(luò)的學(xué)習(xí)過程實(shí)際上是尋找極大極小值問題,生成對(duì)抗網(wǎng)絡(luò)的目標(biāo)函數(shù)如式(1)所示。
對(duì)式(1)進(jìn)行一步的推導(dǎo)可以觀察到,生成對(duì)抗網(wǎng)絡(luò)中生成器Gθ的訓(xùn)練目標(biāo)實(shí)際上是求解Jensen-Shannon散度(JS散度)的最小值。JS散度的定義如式(2)所示。
其中:M12(Pdata+G),系數(shù)12為生成器和真實(shí)數(shù)據(jù)分布的影響力權(quán)重。
為了解決序列數(shù)據(jù)的離散且不可微問題,序列對(duì)抗生成模型SeqGAN引入了強(qiáng)化學(xué)習(xí)的思想。該模型將強(qiáng)化學(xué)習(xí)中的隨機(jī)策略運(yùn)用到生成對(duì)抗網(wǎng)絡(luò)中,通過策略梯度更新模型的參數(shù)來解決離散數(shù)據(jù)生成過程中的不可微問題。在SeqGAN中,將生成器模型的目標(biāo)函數(shù)進(jìn)行了改進(jìn),具體如式(3)所示。
其中:s表示生成器所生成的一個(gè)完整序列;Qt(st,xt)是狀態(tài)—?jiǎng)幼髦岛瘮?shù),表示從狀態(tài)st開始,采取行動(dòng)xt的累計(jì)獎(jiǎng)勵(lì)。SeqGAN將判別器對(duì)生成器生成序列的評(píng)分作為強(qiáng)化學(xué)習(xí)中的獎(jiǎng)勵(lì)。為了解決訓(xùn)練過程中存在的曝光偏差問題,在生成器訓(xùn)練過程中,使用蒙特卡羅搜索和roll-out策略對(duì)未確定的T-t個(gè)序列項(xiàng)進(jìn)行采樣,來評(píng)估中間狀態(tài)的動(dòng)作值。狀態(tài)—?jiǎng)幼髦岛瘮?shù)定義如式(4)所示。
SeqGAN模型的判別器,其目標(biāo)函數(shù)如式(5)所示。
由于SeqGAN對(duì)于強(qiáng)化學(xué)習(xí)思想的依賴,導(dǎo)致了生成對(duì)抗網(wǎng)絡(luò)中的模式崩潰問題更為嚴(yán)重。換而言之,SeqGAN雖然獲得了與訓(xùn)練集真實(shí)數(shù)據(jù)概率分布具有較高相似度的生成概率分布,但是SeqGAN的高準(zhǔn)確度是以犧牲生成數(shù)據(jù)的多樣性作為代價(jià)。
1.3 協(xié)同訓(xùn)練生成模型
SeqGAN模型中,生成器在對(duì)抗學(xué)習(xí)過程中,隨著訓(xùn)練次數(shù)的增多,其生成樣本多樣性逐漸減小。針對(duì)SeqGAN模型存在生成數(shù)據(jù)缺少多樣性的問題,CoT模型引入了中間協(xié)同訓(xùn)練模型M來引導(dǎo)協(xié)助生成器的訓(xùn)練。該模型的最終目標(biāo)函數(shù)如式(6)所示。該模型將生成對(duì)抗網(wǎng)絡(luò)中的最大最小化目標(biāo)函數(shù)轉(zhuǎn)換成了完全的最大化問題。
2 自適應(yīng)序列生成方法
在生成對(duì)抗學(xué)習(xí)中引入強(qiáng)化學(xué)習(xí)方法,盡管在生成準(zhǔn)確度上有了明顯的效果提升,但是生成對(duì)抗網(wǎng)絡(luò)在對(duì)抗訓(xùn)練階段時(shí)常會(huì)出現(xiàn)模式崩潰現(xiàn)象。在CoT模型提出后,模式崩潰問題得到了緩解。本章將通過建立元學(xué)習(xí)任務(wù)來進(jìn)一步緩解模式崩潰問題,尋找生成器生成高質(zhì)量序列和多樣性序列之間的最佳平衡,在保證準(zhǔn)確率的同時(shí),提升生成樣本多樣性。為此,本文分別在中間輔助生成模型訓(xùn)練階段和生成器訓(xùn)練階段中設(shè)置M-元控制器和G-元控制器兩個(gè)元控制器。其中,M-元控制器在數(shù)據(jù)層面控制輸入到中間輔助生成器M的生成樣本和真實(shí)樣本與的混合程度,G-元控制器在模型層面控制生成器訓(xùn)練過程梯度的更新。ECoT模型的整體框架如圖1所示。
2.1 M-元控制器
中間輔助生成模型的訓(xùn)練階段,中間輔助生成模型由生成模型和真實(shí)樣本分布的混合概率分布函數(shù)組成。其中,12(Pdata+Gθ)表示混合數(shù)據(jù)分布M*,將生成器和訓(xùn)練集的采樣結(jié)果作為中間輔助生成器的輸入,這種方式在一定程度上緩解了曝光偏差問題[11,13]。CoT模型中,中間輔助生成器的目標(biāo)函數(shù)如式(7)所示。
其中:使用來自生成器的樣本計(jì)算生成模型與中間輔助生成模型的KL散度值,使用來自訓(xùn)練集的樣本計(jì)算真實(shí)數(shù)據(jù)分布Pdata與中間輔助生成模型M?的KL散度值,最后計(jì)算兩個(gè)KL散度平均值,式(7)實(shí)質(zhì)上是混合數(shù)據(jù)分布和中間輔助生成器對(duì)應(yīng)分布的KL散度。模型訓(xùn)練過程中,中間輔助生成模型擬合的分布函數(shù)向真實(shí)數(shù)據(jù)分布靠攏,同時(shí)使用生成器來控制中間輔助生成模型每次向真實(shí)數(shù)據(jù)分布靠攏的程度。式(7)中將生成器和真實(shí)數(shù)據(jù)分布視為同等影響力的設(shè)置,對(duì)生成多樣性的提高具有局限性。針對(duì)該問題,將學(xué)習(xí)動(dòng)態(tài)權(quán)重設(shè)置為第一元學(xué)習(xí)任務(wù),也是M-元控制器的主要內(nèi)容。
在本文中,從數(shù)據(jù)采樣的角度分析中間輔助生成器對(duì)生成器生成多樣性的影響。在多數(shù)生成對(duì)抗網(wǎng)絡(luò)模型中對(duì)真實(shí)數(shù)據(jù)、生成數(shù)據(jù)均采用等比例分層采樣的方法,采樣出等量的真實(shí)樣本和生成樣本,這種做法存在兩個(gè)問題:a)在訓(xùn)練初期,生成器并不具備良好的生成效果,所以生成的數(shù)據(jù)對(duì)中間輔助器的訓(xùn)練協(xié)同互助的作用有限;b)在訓(xùn)練后期,生成器生成與真實(shí)數(shù)據(jù)十分相似的數(shù)據(jù),該期間中間輔助生成器需要增加生成的多樣性,但是由于生成樣本與真實(shí)樣本的等比例分層采樣出的輸入數(shù)據(jù)無法為生成器提供更多生成多樣性學(xué)習(xí)的引導(dǎo)。對(duì)于上述兩個(gè)問題,本文利用M-元控制器學(xué)習(xí)長(zhǎng)尾分布,控制中間輔助生成器輸入數(shù)據(jù)的混合程度,有側(cè)重地為中間輔助生成器提供訓(xùn)練樣本,過程如圖2所示。M-元控制器的實(shí)現(xiàn)如式(8)所示,中間輔助生成器的目標(biāo)函數(shù)如式(9)所示。
其中:λM值由式(10)所得。
M-元控制器在數(shù)據(jù)層面動(dòng)態(tài)調(diào)控真實(shí)樣本分布和生成樣本分布對(duì)中間輔助生成器模型訓(xùn)練的影響。在M-元控制器的調(diào)節(jié)下,真實(shí)數(shù)據(jù)分布以及生成數(shù)據(jù)分布動(dòng)態(tài)引導(dǎo)中間輔助生成模型訓(xùn)練,為中間輔助生成器多樣性的提高隱式地提供了方向。式(10)的MLP選擇sigmoid作為輸出單元激活函數(shù)。
2.2 G-元控制器
中間輔助生成器訓(xùn)練后轉(zhuǎn)而進(jìn)行生成器的訓(xùn)練,現(xiàn)有的方法均是通過最小化JS散度作為生成器目標(biāo)函數(shù),具體如式(11)所示。
最小化JS散度實(shí)際上是通過中間輔助生成器間接引導(dǎo)生成器Gθ訓(xùn)練,具體過程如圖3(a)~(d)所示。假設(shè)真實(shí)數(shù)據(jù)服從高斯分布的概率密度函數(shù)擬合,其中“--”為
ECoT模型生成器的概率密度函數(shù)
,“··”為真實(shí)數(shù)據(jù)概率密度函數(shù),“__”為中間輔助生成器的概率密度函數(shù)。中間輔助生成器指導(dǎo)訓(xùn)練的過程中隱式地提升生成器的生成多樣性,但是這種隱式提升生成多樣性的方法仍存在局限性,對(duì)多樣性的提高具有不確定性。為了改進(jìn)這種情況,同時(shí)緩解對(duì)抗模型中常見的模式崩潰問題,本文添加了多樣性約束項(xiàng),該約束項(xiàng)的目的是提供可顯式優(yōu)化損失部分,在目標(biāo)函數(shù)中體現(xiàn)生成多樣性的訓(xùn)練,該約束項(xiàng)指導(dǎo)生成器逼近均勻分布,進(jìn)而提升生成的多樣性,在生成器的訓(xùn)練過程中添加更多的隨機(jī)性。
至此,生成器的目標(biāo)函數(shù)已經(jīng)具備可以顯式優(yōu)化生成準(zhǔn)確性和多樣性的能力。但是準(zhǔn)確性目標(biāo)項(xiàng)和多樣性目標(biāo)項(xiàng)彼此是相互對(duì)立、相互競(jìng)爭(zhēng)的關(guān)系,在訓(xùn)練過程中難以達(dá)到預(yù)期的效果,甚至?xí)?dǎo)致目標(biāo)函數(shù)難以收斂。為了解決該問題,設(shè)置G-元控制器顯式地控制生成器生成準(zhǔn)確性以及多樣性的訓(xùn)練與學(xué)習(xí),生成器的最終目標(biāo)函數(shù)如式(12)所示。
其中: U表示均勻分布U(0,SG);SG表示訓(xùn)練集樣本數(shù)量;Cs表示訓(xùn)練數(shù)據(jù)集中樣本種類數(shù);λG由式(13)所得。
其中:D(·)表示判別器的判別結(jié)果;ξ表示判別器質(zhì)量系數(shù)。
G-元控制器控制目標(biāo)函數(shù)中準(zhǔn)確性目標(biāo)項(xiàng)和多樣性目標(biāo)項(xiàng)對(duì)生成器訓(xùn)練的影響,進(jìn)而指導(dǎo)訓(xùn)練的方向。接下來,本文從梯度更新角度對(duì)多樣性訓(xùn)練進(jìn)行分析,設(shè)FA表示目標(biāo)函數(shù)中的準(zhǔn)確性目標(biāo)項(xiàng),F(xiàn)D表示目標(biāo)函數(shù)中多樣性目標(biāo)項(xiàng)。在訓(xùn)練初期,隨機(jī)初始化后,準(zhǔn)確性目標(biāo)項(xiàng)的值偏低,多樣性占據(jù)優(yōu)勢(shì),G-元控制器在該階段提高準(zhǔn)確性目標(biāo)項(xiàng)的值,即增加FA的模長(zhǎng),減少FD的模長(zhǎng),如圖4(a)所示;而且在模型的最低點(diǎn)附近,準(zhǔn)確性目標(biāo)項(xiàng)展現(xiàn)優(yōu)勢(shì),G-元控制器減少FA的模長(zhǎng),增加FD的模長(zhǎng),如圖4(b)所示。通過對(duì)梯度下降方向的控制,G-元控制器在模型層面尋找最佳更新梯度。尋找準(zhǔn)確性和多樣性的最優(yōu)解,以緩解模式崩潰問題。
為了進(jìn)一步保證生成模型的準(zhǔn)確性,防止由于過分追求生成多樣性導(dǎo)致的生成準(zhǔn)確性丟失,本文設(shè)置對(duì)抗判別器Dψ。對(duì)抗判別器主要用于輔助生成器矯正準(zhǔn)確性訓(xùn)練方向和為G-元控制器提供先驗(yàn)知識(shí)輸入。判別器的目標(biāo)函數(shù)如式(14)所示,判別式的輸出表示判別樣本為訓(xùn)練集真實(shí)樣本的概率值。
在對(duì)抗學(xué)習(xí)過程,生成器選用交叉熵作為損失函數(shù),專注于準(zhǔn)確性的提升。在對(duì)抗學(xué)習(xí)結(jié)束后,判別器的輸出和判別器的交叉熵?fù)p失值組合作為G-元控制器輸入。判別器輸出對(duì)輸入樣本的判別概率,由于判別器處于動(dòng)態(tài)訓(xùn)練中,所以需要計(jì)算判別器的判別質(zhì)量,稱為判別器質(zhì)量系數(shù),如式(15)所示,判別質(zhì)量系數(shù)與判別器對(duì)生成樣本的判別輸出向量進(jìn)行元素相乘,作為G-元控制器部分輸入。
其中:loss表示判別器的交叉熵?fù)p失值。
使用M-元控制器和G-元控制器對(duì)生成器Gθ和中間輔助生成器M?的學(xué)習(xí)過程進(jìn)行兩層控制調(diào)節(jié),在數(shù)據(jù)和模型兩個(gè)層面進(jìn)行學(xué)習(xí),元控制器優(yōu)化生成器Gθ和中間輔助生成器M?向真實(shí)分布學(xué)習(xí)的過程,使模型能夠在保證生成樣本準(zhǔn)確度的同時(shí),提升模型生成樣本的多樣性,同時(shí),增加對(duì)抗學(xué)習(xí)器以保證生成器準(zhǔn)確率的穩(wěn)定訓(xùn)練。算法1中給出了自適應(yīng)序列生成方法的完整算法過程。
算法1 自適應(yīng)序列生成方法
輸入: Gθ,Dψ,M?;從真實(shí)樣本分布Pdata采樣樣本。
輸出:生成器Gθ。
initialize Gθ,Dψ,M?,λM,λG with random weights θ,ψ,?,ωM,ωG
pretrain Gθ with samples from Pdata
while not done do
for Nm steps do
sample sg from Gθ and" Sp from Pdata
compute loss of" M? and λM
update ?,ωM
end for
samples S from Gθ
compute loss of Gθ and λG
update θ,ωG
end while
3 實(shí)驗(yàn)
本章使用簡(jiǎn)寫ECoT表示自適應(yīng)序列生成方法。首先,本文分別在SeqGAN[12]和Meta-CoTGAN[1]中引入的合成離散序列生成數(shù)據(jù)集以及常用的文本生成數(shù)據(jù)集COCO圖像字幕數(shù)據(jù)集和EMNLP2017 WMT News數(shù)據(jù)集上進(jìn)行對(duì)比實(shí)驗(yàn)。在三個(gè)數(shù)據(jù)集的實(shí)驗(yàn)中,均使用TensorFlow深度學(xué)習(xí)框架和Texygen框架進(jìn)行訓(xùn)練與模型的評(píng)估[16],在本文實(shí)驗(yàn)環(huán)節(jié)中,將生成對(duì)抗網(wǎng)絡(luò)中的對(duì)抗學(xué)習(xí)過程和協(xié)同訓(xùn)練的過程相結(jié)合,設(shè)置生成器Gθ、判別器Dψ、中間輔助生成模型M?,設(shè)置判別器的目標(biāo)函數(shù)為JD(ψ)。
3.1 實(shí)驗(yàn)評(píng)價(jià)指標(biāo)
在實(shí)驗(yàn)中,從生成樣本的準(zhǔn)確性以及生成樣本的多樣性兩個(gè)方面對(duì)生成數(shù)據(jù)進(jìn)行評(píng)估?,F(xiàn)有的大部分文本序列生成工作使用BLEU分?jǐn)?shù)度量生成器在真實(shí)數(shù)據(jù)集上學(xué)習(xí)后所生成的樣本質(zhì)量,而在合成數(shù)據(jù)集上對(duì)模型進(jìn)行評(píng)估時(shí),通過計(jì)算
值來評(píng)估生成樣本準(zhǔn)確性,使用NLLtest評(píng)估生成樣本多樣性[1]。設(shè)計(jì)圖靈測(cè)試任務(wù),依據(jù)專家網(wǎng)絡(luò)提供的先驗(yàn)知識(shí)計(jì)算負(fù)對(duì)數(shù)似然函數(shù)NLLoracle,其具體計(jì)算如式(16)所示,當(dāng)NLLoracle越小時(shí),表示生成模型的準(zhǔn)確度越高。
NLLtest指的是從專家網(wǎng)絡(luò)額外抽取樣本,計(jì)算生成器的負(fù)對(duì)數(shù)似然,NLLtest是用于評(píng)估模型擬合真實(shí)測(cè)試數(shù)據(jù)能力的簡(jiǎn)單指標(biāo)[16],通過評(píng)估生成模型在真實(shí)數(shù)據(jù)概率密度上的覆蓋范圍來評(píng)估生成樣本的多樣性以及生成模型的抗模式崩潰能力。如果模型在真實(shí)數(shù)據(jù)空間中具有更廣泛的覆蓋范圍,則生成的樣本將具有更好的多樣性,對(duì)應(yīng)的損失會(huì)更低。相反地,如果模型存在嚴(yán)重的模式崩潰問題,那么模型在真實(shí)數(shù)據(jù)空間中覆蓋范圍較小,模型將不能很好地代表真實(shí)數(shù)據(jù),并且會(huì)得到較高的損失[1]。其具體計(jì)算如式(17)所示。
BLEU最初被應(yīng)用于機(jī)器翻譯系統(tǒng),是用來衡量機(jī)器翻譯的結(jié)果和人工翻譯的差異的指標(biāo)。假設(shè)輸入標(biāo)準(zhǔn)的人工翻譯結(jié)果,生成器模型生成相應(yīng)的翻譯結(jié)果,將句子長(zhǎng)度設(shè)為n,在生成器模型生成的翻譯結(jié)果中存在m個(gè)單詞是在標(biāo)準(zhǔn)的人工翻譯結(jié)果中重復(fù)出現(xiàn)的,那么稱得到的m/n就是BLEU的1-gram值。通過這種計(jì)算方法來衡量生成結(jié)果的準(zhǔn)確性。根據(jù)k-gram中k的取值不同,BLEU-k的結(jié)果也有所不同,更高階的BLEU衡量生成序列的準(zhǔn)確性,同時(shí)衡量生成序列的流暢性。
3.2 實(shí)驗(yàn)數(shù)據(jù)與結(jié)果分析
3.2.1 專家網(wǎng)絡(luò)合成數(shù)據(jù)集
本文使用已訓(xùn)練好的長(zhǎng)短期記憶網(wǎng)絡(luò)(long short-term memory,LSTM)模型作為專家網(wǎng)絡(luò),其生成的數(shù)據(jù)作為訓(xùn)練集,專家網(wǎng)絡(luò)不僅為模型的訓(xùn)練提供訓(xùn)練數(shù)據(jù)集,還提供訓(xùn)練樣本的先驗(yàn)知識(shí)。在實(shí)驗(yàn)部分,優(yōu)先在專家網(wǎng)絡(luò)上合成的數(shù)據(jù)集進(jìn)行模型評(píng)價(jià),在SeqGAN模型實(shí)驗(yàn)中首次使用專家網(wǎng)絡(luò)合成的數(shù)據(jù)集。實(shí)驗(yàn)中設(shè)置專家網(wǎng)絡(luò)模擬現(xiàn)實(shí)世界中的序列數(shù)據(jù),生成長(zhǎng)度為20并且序列項(xiàng)總數(shù)為5 000的訓(xùn)練數(shù)據(jù),總共生成10 000個(gè)序列用做訓(xùn)練。訓(xùn)練前,對(duì)所有生成器的參數(shù)進(jìn)行初始化,且參數(shù)使用正態(tài)分布初始化器初始化,且在所有的生成器預(yù)訓(xùn)練階段均選擇極大似然估計(jì)作為預(yù)訓(xùn)練過程[7],選擇LSTM作為生成器與判別器模型,Meta-CoTGAN溫度設(shè)置為1 000。在預(yù)訓(xùn)練階段中,本文首先訓(xùn)練生成器80個(gè)epoch,然后訓(xùn)練判別器80個(gè)epoch,隨后進(jìn)入到對(duì)抗訓(xùn)練階段。每次對(duì)抗階段,本文更新一次生成器后,判別器進(jìn)行15次小批量梯度更新[7]。此外,在LeakGAN模型的訓(xùn)練過程中,每10次對(duì)抗訓(xùn)練后,生成器和判別器將進(jìn)行5次極大似然估計(jì)訓(xùn)練[7]。本次實(shí)驗(yàn)除了將ECoT與CoT及Meta-CotGAN模型等協(xié)同訓(xùn)練模型進(jìn)行效果對(duì)比實(shí)驗(yàn)之外,還進(jìn)行了同極大似然估計(jì)模型MLE以及融合了強(qiáng)化學(xué)習(xí)思想的序列生成對(duì)抗網(wǎng)絡(luò)SeqGAN及其變體(如MaliGAN[11]、RankGAN[17]、LeakGAN[18]模型等)實(shí)驗(yàn)結(jié)果的對(duì)比展示。
在表2和圖5中,將ECoT與MLE、GAN及其強(qiáng)化學(xué)習(xí)變體、CoT模型及其變體在NLLoracle和NLLtest兩個(gè)指標(biāo)上的結(jié)果進(jìn)行比較展示。在圖5中可以清楚地發(fā)現(xiàn),采用MLE作為基礎(chǔ)生成器的CoT模型和極大似然估計(jì)MLE模型生成樣本的準(zhǔn)確率是相對(duì)最低的,NLLoracle值達(dá)到了9以上,而SeqGAN模型及其引入強(qiáng)化學(xué)習(xí)變體的模型方法,如RankGAN、MaliGAN、LeakGAN的NLLoracle值分別為8.74、8.40、8.91、8.57,整體取值在8.4~8.9浮動(dòng)。對(duì)比NLLtest值可以發(fā)現(xiàn),LeakGAN的NLLtest值最小為4.54,也就是說在SeqGAN模型引入強(qiáng)化學(xué)習(xí)變體的模型方法中,LeakGAN的準(zhǔn)確率及多樣性是最高的。采用LeakGAN模型的架構(gòu)作為基礎(chǔ)生成器架構(gòu)的CoT-Strong模型,其NLLoracle值減小到了8.24,NLLtest減小到4.36。在梯度中建立元學(xué)習(xí)任務(wù)的Meta-CoTGAN的NLLoracle值為8.18,NLLtest值減小到了4.30。本文提出的ECoT模型的NLLoracle、NLLtest值分別為7.57、4.12,其中NLLtest值較Meta-CoTGAN降低了4.1%,而NLLoracle相比于Meta-CoTGAN更是降低了7.5%。由上述實(shí)驗(yàn)數(shù)據(jù)可以看到,本文提出的ECoT模型在生成序列的準(zhǔn)確度上得到了提升,甚至達(dá)到了最優(yōu),并且對(duì)比NLLtest值,實(shí)驗(yàn)結(jié)果表明本文提出的ECoT模型在生成樣本的多樣性上具有更好的效果,抗模式崩潰能力更強(qiáng)。通過兩個(gè)方面的實(shí)驗(yàn)對(duì)比可以發(fā)現(xiàn),本文方法在生成的準(zhǔn)確度以及生成樣本的多樣性上均達(dá)到了最優(yōu)。
3.2.2 COCO圖像字幕數(shù)據(jù)集和EMNLP2017 WMT News數(shù)據(jù)集
COCO(common objects in context)起源于微軟于2014年出資標(biāo)注的Microsoft COCO數(shù)據(jù)集[19]。COCO數(shù)據(jù)集中的圖像分為訓(xùn)練集、驗(yàn)證集和測(cè)試集。COCO數(shù)據(jù)集是一個(gè)大型的、豐富的物體檢測(cè)、分割和字幕數(shù)據(jù)集。COCO數(shù)據(jù)集以對(duì)圖片中場(chǎng)景的理解作為其主要目標(biāo),工作的具體內(nèi)容是從復(fù)雜的日常場(chǎng)景中截取圖像中的目標(biāo),通過精確的語義分割進(jìn)行位置標(biāo)定。圖片包括91種對(duì)象目標(biāo),328 000張圖片和2 500 000個(gè)標(biāo)簽。目前為止是語義分割的最大數(shù)據(jù)集,提供的類別有80 類,超過33 萬張圖片,其中20 萬張有標(biāo)注,整個(gè)數(shù)據(jù)集中個(gè)體的數(shù)目超過150 萬個(gè)。該數(shù)據(jù)集包含多組圖像描述對(duì)。本文將整個(gè)圖像數(shù)據(jù)集上的圖像標(biāo)題作為要生成的文本,其中大多數(shù)句子約為10個(gè)單詞。因此,本文對(duì)數(shù)據(jù)集進(jìn)行了一些預(yù)處理。COCO圖像字幕訓(xùn)練數(shù)據(jù)集由20 734個(gè)單詞和417 126個(gè)句子組成。本文刪除頻率低于10的單詞以及包含它們的句子。經(jīng)過預(yù)處理,最終數(shù)據(jù)集中包含了20 000個(gè)句子。EMNLP2017 WMT News數(shù)據(jù)集作為長(zhǎng)文本語料庫進(jìn)行實(shí)驗(yàn)評(píng)估。EMNLP2017 WMT News是機(jī)器翻譯領(lǐng)域最重要的公開數(shù)據(jù)集,其數(shù)據(jù)規(guī)模較大,含有多種語言的文本,通常在百萬句到千萬句不等。這部分實(shí)驗(yàn)中,首先從原始EMNLP2017 WMT News數(shù)據(jù)集中選擇部分新聞數(shù)據(jù)。選擇出來新聞數(shù)據(jù)集由6 459個(gè)單詞和397 726個(gè)句子組成。本文通過消除頻率低于4 050的單詞以及包含這些低頻單詞的句子對(duì)數(shù)據(jù)進(jìn)行預(yù)處理。在本實(shí)驗(yàn)中,刪除了長(zhǎng)度小于20的句子,最終采樣出20 000個(gè)句子作為實(shí)驗(yàn)數(shù)據(jù)集。
為了對(duì)比本文模型同生成對(duì)抗網(wǎng)絡(luò)與強(qiáng)化學(xué)習(xí)的組合在準(zhǔn)確度和多樣性上的提升,驗(yàn)證協(xié)同訓(xùn)練方法比生成對(duì)抗網(wǎng)絡(luò)與強(qiáng)化學(xué)習(xí)組合更加適合序列離散數(shù)據(jù)的生成,在兩個(gè)數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),使用BLEU對(duì)準(zhǔn)確度進(jìn)行評(píng)估,使用NLLtest評(píng)估生成樣本的多樣性。表3展示模型在COCO圖像字幕數(shù)據(jù)集以及EMNLP2017 WMT News數(shù)據(jù)集上的BLEU-2~BLEU-5評(píng)分,用于測(cè)量樣本質(zhì)量的分?jǐn)?shù),
NLLtest值用于評(píng)估生成樣本的多樣性。為了更清楚地展示,使用柱狀圖進(jìn)行實(shí)驗(yàn)結(jié)果對(duì)比,如圖6所示,(a)表示在COCO圖像字幕數(shù)據(jù)集上的結(jié)果對(duì)比,(b)表示在EMNLP2017 WMT News數(shù)據(jù)集上的結(jié)果對(duì)比。在兩個(gè)真實(shí)數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果表明,不論是在準(zhǔn)確性上還是在多樣性上,協(xié)同訓(xùn)練方法在離散數(shù)據(jù)的生成問題上比生成對(duì)抗網(wǎng)絡(luò)與強(qiáng)化學(xué)習(xí)組合更具優(yōu)勢(shì)。ECoT模型同協(xié)同訓(xùn)練方法CoT模型相比,雖然在生成多樣性上與CoT接近,但是在生成準(zhǔn)確性上,ECoT較CoT具有更優(yōu)的表現(xiàn)。實(shí)驗(yàn)結(jié)果表明,ECoT具備尋找準(zhǔn)確性和多樣性最優(yōu)組合的能力。
在表4中,將ECoT模型同MLE、CoT在COCO圖像字幕數(shù)據(jù)集訓(xùn)練采樣結(jié)果進(jìn)行對(duì)比。在表4中能夠直觀發(fā)現(xiàn),本文提出的ECoT模型能夠生成表達(dá)更為人性化的語句,并且更具多樣性;而CoT模型雖然能夠生成表達(dá)多樣的語句序列,但是其生成的句子在正確性上略顯不足;MLE模型作為基礎(chǔ)的序列生成模型,在生成短句上具備更好的性能,但是在更為復(fù)雜的長(zhǎng)句生成中,生成結(jié)果會(huì)出現(xiàn)難以意料的錯(cuò)誤生成結(jié)果,并且MLE模型生成的語句較顯單一,并不具備較好的表達(dá)性。相比較而言,本文提出的ECoT模型在兼顧生成準(zhǔn)確性的同時(shí),能夠生成句式更加豐富,表達(dá)更為多樣性的語句序列。
4 結(jié)束語
本文提出了一種新的方法ECoT模型,以協(xié)調(diào)序列生成的準(zhǔn)確性和多樣性,為提升生成樣本的多樣性提供了一種新思路。在協(xié)同訓(xùn)練思想基礎(chǔ)之上,ECoT分別在中間輔助生成器的訓(xùn)練階段引入了M-元控制器,在生成器的訓(xùn)練階段引入了G-元控制器。通過對(duì)模型施加兩層控制,在訓(xùn)練中尋找接近真實(shí)數(shù)據(jù)分布的同時(shí),又能夠保持其生成樣本多樣性的平衡態(tài)模型,并且本文方法在面對(duì)模式崩潰問題時(shí)展現(xiàn)了有效性。實(shí)驗(yàn)結(jié)果表明,本文方法在準(zhǔn)確性和多樣性上優(yōu)于對(duì)比方法,同時(shí)在與LeakGAN等強(qiáng)化學(xué)習(xí)架構(gòu)相結(jié)合時(shí),本文方法能夠展現(xiàn)出更加優(yōu)秀的性能。在未來的研究中,將使用元學(xué)習(xí)方法與CoT深度結(jié)合,對(duì)比不同生成模型的生成效果,進(jìn)一步提高生成序列的多樣性。
參考文獻(xiàn):
[1]Yin Haiyan,Li Dingcheng,Li Xu,et al.Meta-CoTGAN:a meta coo-perative training paradigm for improving adversarial text generation[C]//Proc of AAAI Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2020:9466-9473.
[2]Bahdanau D,Cho K H,Bengio Y.Neural machine translation by jointly learning to align and translate[C]//Proc of the 3rd International Conference on Learning Representations.2015:1-15.
[3]張涼,楊燕,陳成才,等.基于多視角對(duì)抗學(xué)習(xí)的開放域?qū)υ捝赡P停跩].計(jì)算機(jī)應(yīng)用研究,2021,38(2):372-376.(Zhang Liang,Yang Yan,Chen Chengcai,et al.Open domain dialogue generation model based on multi-view adversarial learning[J].Application Research of Computers,2021,38(2):372-376.)
[4]LiuShuman,Chen Hongshen,Ren Zhaochun,et al.Knowledge diffusion for neural dialogue generation[C]//Proc of the 56th Annual Meeting of the Association for Computational Linguistics.Stroudsburg,PA:Association for Computational Linguistics,2018:1489-1498.
[5]Vaswani A,Shazeer N,Parmar N,et al.Attention is all you need[C]//Proc of the 31st International Conference on Neural Information Processing Systems.Red Hook,NY:Curran Associates Inc.,2017:5998-6008.
[6]Lin Junyang,Sun Xu,Ma Shuangming,et al.Global encoding for abstractive summarization[C]//Proc of the 56th Annual Meeting of the Association for Computational Linguistics.Stroudsburg,PA:Association for Computational Linguistics,2018:163-169.
[7]Wu Chaoyuan,Ahmed A,Beutel A,et al.Recurrent recommender networks[C]//Proc of the 10th ACM International Conference on Web Search and Data Mining.New York:ACM Press,2017:495-503.
[8]Tang Jiaxi,Belletti F,Jain S,et al.Towards neural mixture recommender for long range dependent user sequences[C]//Proc of the 28th International Conference on World Wide Web.New York:ACM Press,2019:1782-1793.
[9]Ying Haochao,Zhuang Fuzhen,Zhang Fu Zhang,et al.Sequential re-commender system based on hierarchical attention networks[C]//Proc of the 27th International Joint Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2018:3926-3932.
[10]伍鑫,黃勃,方志軍,等.序列生成對(duì)抗網(wǎng)絡(luò)在推薦系統(tǒng)中的應(yīng)用[J].計(jì)算機(jī)工程與應(yīng)用,2020,56(23):175-179.(Wu Xin,Huang Bo,F(xiàn)ang Zhijun,et al.Application of sequence generative adversarial network in recommendation system[J].Computer Engineering and Applications,2020,56(23):175-179.)
[11]Che Tong,Li Yanren,Zhang Ruixiang,et al.Maximum-likelihood augmented discrete generative adversarial networks[EB/OL].(2017-02-26).https://arxiv .org/abs/1 702.07983.
[12]Yu Lantao,Zhang Weinan,Wang Jun,et al.SeqGAN:sequence gene-rative adversarial nets with policy gradient[C]//Proc of AAAI Confe-rence on Artificial Intelligence.Palo Alto,CA:AAAI Press,2017:2852-2858.
[13]Kusner M J,Hernández-Lobato J M.GANs for sequences of discrete elements with the Gumbel-softmax distribution[EB/OL].(2016-11-16).https://arxiv .org/abs/1611.04051v1.
[14]Lu Sidi,Yu Lantao,F(xiàn)eng Siyuan,et al.CoT:cooperative training for generative modeling of discrete data[C]//Proc of International Conference on Machine Learning.2019:4164-4172.
[15]Goodfellow I J,Pouget-Abadie J,Mirza M,et al.Generative adversarial nets[C]//Proc of the 27th International Conference on Neural Information Processing Systems.Cambridge,MA:MIT Press,2014:2672-2680.
[16]Zhu Yaoming,Lu Sidi,Zheng Lei,et al.Texygen:a benchmarking platform for text generation models[C]//Proc of the 41st International ACM SIGIR Conference on Research amp; Development in Information Retrieval.New York:ACM Press,2018:1097-1100.
[17]Lin K,Li Diang,He Xiaodong,et al.Adversarial ranking for language generation[C]//Proc of the 31st International Conference on Neural Information Processing Systems.Red Hook,NY:Curran Associates Inc.,2018:3158-3168.
[18]Guo Jiaxian,Lu Sidi,Cai Han,et al.Long text generation via adversa-rial training with leaked information[C]//Proc of AAAI Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2018:5141-5148.
[19]Lin T Y,Maire M,Belongie S,et al.Microsoft COCO:common objects in context[C]//Proc of European Conference on Computer Vision.Cham:Springer,2014:740-755.