鄭 賽,李天瑞*,黃 維
(1.西南交通大學(xué) 計(jì)算機(jī)與人工智能學(xué)院,成都 611756;2.云計(jì)算與智能技術(shù)四川省高校重點(diǎn)實(shí)驗(yàn)室(西南交通大學(xué)),成都 611756)
在過去幾年里,機(jī)器學(xué)習(xí)在人工智能應(yīng)用領(lǐng)域迅速發(fā)展,這些機(jī)器學(xué)習(xí)技術(shù)尤其是深度學(xué)習(xí)[1]的成功,都建立在大量數(shù)據(jù)的基礎(chǔ)上。然而,我們在現(xiàn)實(shí)世界中遇到的數(shù)據(jù)往往是小規(guī)模的、碎片化的,例如,來自移動(dòng)終端設(shè)備、物聯(lián)網(wǎng)設(shè)備的數(shù)據(jù)和大量分布在城市中的傳感器的數(shù)據(jù)都擁有這兩個(gè)特點(diǎn)。出于保護(hù)用戶隱私和數(shù)據(jù)安全的要求,簡單地將這些數(shù)據(jù)聚合在一起進(jìn)行模型訓(xùn)練是不可行的。2018 年,歐盟開始執(zhí)行《通用數(shù)據(jù)保護(hù)條例》;2021 年11 月,我國開始正式實(shí)施《中華人民共和國信息保護(hù)法》,國內(nèi)數(shù)據(jù)監(jiān)管法律對數(shù)據(jù)隱私保護(hù)的監(jiān)管也愈發(fā)嚴(yán)格。
聯(lián)邦學(xué)習(xí)(Federated Learning,F(xiàn)L)是一種解決上述問題的機(jī)器學(xué)習(xí)設(shè)置。聯(lián)邦學(xué)習(xí)這一概念由McMahan 等[2]于2017 年首次提出,最近的研究者們對聯(lián)邦學(xué)習(xí)提出了一個(gè)更寬泛和準(zhǔn)確的定義[3]:聯(lián)邦學(xué)習(xí)是一種機(jī)器學(xué)習(xí)設(shè)置,其中多個(gè)客戶端在中央服務(wù)器或服務(wù)提供商的協(xié)調(diào)下協(xié)作解決機(jī)器學(xué)習(xí)問題。每個(gè)客戶端的原始數(shù)據(jù)都存儲在本地,不進(jìn)行交換或傳輸;取而代之的是,使用實(shí)時(shí)聚合的模型更新來實(shí)現(xiàn)學(xué)習(xí)目標(biāo)。聯(lián)邦平均(Federated Averaging,F(xiàn)edAvg)算法是一種典型的聯(lián)邦學(xué)習(xí)算法,主要包含兩個(gè)步驟:客戶端接收服務(wù)器發(fā)送來的全局模型并進(jìn)行訓(xùn)練得到本地模型;服務(wù)器接收來自多個(gè)客戶端的本地模型,通過加權(quán)平均這些本地模型得到一個(gè)新的全局模型,再將其發(fā)送回客戶端。雖然FedAvg 一定程度上解決了聯(lián)邦學(xué)習(xí)中的兩個(gè)核心問題:統(tǒng)計(jì)異質(zhì)性問題和通信成本問題,但是,作為一個(gè)樸素的聯(lián)邦學(xué)習(xí)算法,它仍然有許多地方可以改進(jìn)。
本文提出一個(gè)結(jié)合生成模型和深度遷移學(xué)習(xí)的聯(lián)邦學(xué)習(xí)算法FedGT(Federated Generative Transfer)。該算法只需要客戶端和服務(wù)器進(jìn)行一輪通信,能大幅降低聯(lián)邦學(xué)習(xí)的通信成本。同時(shí),由于客戶端得到了個(gè)性化模型,所以統(tǒng)計(jì)異質(zhì)性問題也得到了一定程度的緩解。本文工作的主要特點(diǎn)在于引入生成模型和模擬數(shù)據(jù)在服務(wù)器構(gòu)建全局模型,并且僅需一輪通信。此前的算法都只傳輸預(yù)測模型參數(shù)并進(jìn)行聚合,而本文算法則通過傳輸生成模型來生成模擬數(shù)據(jù)。具體工作如下:
1)利用生成模型在服務(wù)器生成模擬數(shù)據(jù)來建立全局預(yù)測模型,可以在一輪通信下保證聯(lián)邦學(xué)習(xí)的最終性能。
2)使用深度遷移學(xué)習(xí)中的微調(diào)來進(jìn)一步適應(yīng)客戶端不同分布的數(shù)據(jù),從而緩解統(tǒng)計(jì)異質(zhì)性的問題。
3)在不同的數(shù)據(jù)集中使用不同的模型進(jìn)行實(shí)驗(yàn),結(jié)果表明本文方法具有一定的通用性。
作為一種能夠解決數(shù)據(jù)孤島和數(shù)據(jù)隱私安全問題的機(jī)器學(xué)習(xí)設(shè)置,相關(guān)研究者已經(jīng)將聯(lián)邦學(xué)習(xí)應(yīng)用到了許多領(lǐng)域:Hard 等[4]將FedAvg 算法用于預(yù)測智能手機(jī)鍵盤輸入法的下一個(gè)單詞;史鼎元等[5]將聯(lián)邦學(xué)習(xí)應(yīng)用在信息檢索領(lǐng)域,提出了聯(lián)邦排序?qū)W習(xí)算法;Muhammad 等[6]將聯(lián)邦學(xué)習(xí)應(yīng)用在推薦系統(tǒng)。然而,聯(lián)邦學(xué)習(xí)這一技術(shù)的落地仍然面臨許多挑戰(zhàn)。
聯(lián)邦學(xué)習(xí)環(huán)境中的客戶端通常是來自現(xiàn)實(shí)世界的終端設(shè)備,例如移動(dòng)電話、可穿戴設(shè)備和智能設(shè)備。由于時(shí)間、地理位置和用戶習(xí)慣等因素,這些設(shè)備上的數(shù)據(jù)是非獨(dú)立同分布(non-Independent and Identically Distributed,non-IID)的,這一問題被稱為聯(lián)邦學(xué)習(xí)的統(tǒng)計(jì)異質(zhì)性[7]問題。統(tǒng)計(jì)異質(zhì)性導(dǎo)致了一些基于數(shù)據(jù)獨(dú)立同分布假設(shè)的傳統(tǒng)分布式機(jī)器學(xué)習(xí)算法性能低下。為了解決這一問題,聯(lián)邦學(xué)習(xí)改進(jìn)算法被相繼提出,例如,F(xiàn)edProx[8]是一個(gè)面向統(tǒng)計(jì)異質(zhì)性的算法,它在每個(gè)客戶端的原有優(yōu)化目標(biāo)上增加了一個(gè)衡量本地模型和全局模型差異的L2 正則化項(xiàng),使得模型在non-IID 數(shù)據(jù)上訓(xùn)練更加穩(wěn)定,提高了收斂速度。
由于分布在客戶端的數(shù)據(jù)是non-IID 的,這使訓(xùn)練單個(gè)全局模型難以適用于所有客戶端,所以為每個(gè)客戶端構(gòu)建個(gè)性化模型十分重要。個(gè)性化聯(lián)邦學(xué)習(xí)[9]算法通常與遷移學(xué)習(xí)[10]、知識蒸餾[11]、元學(xué)習(xí)[12]、多任務(wù)學(xué)習(xí)[13]等其他機(jī)器學(xué)習(xí)技術(shù)相結(jié)合。遷移學(xué)習(xí)使深度學(xué)習(xí)模型能夠利用在解決一個(gè)問題時(shí)獲得的知識來解決另一個(gè)相關(guān)問題,Wang 等[14]利用客戶端的本地?cái)?shù)據(jù)對全局模型參數(shù)進(jìn)行再次更新,從而得到客戶端的個(gè)性化模型。知識蒸餾通過讓學(xué)生網(wǎng)絡(luò)模仿教師網(wǎng)絡(luò),將大型教師網(wǎng)絡(luò)中的知識提取到小型學(xué)生網(wǎng)絡(luò)中,減少了網(wǎng)絡(luò)的參數(shù)量。過擬合是聯(lián)邦學(xué)習(xí)個(gè)性化模型的一個(gè)重要挑戰(zhàn),Yu 等[15]提出將全局模型視為教師、將個(gè)性化模型視為學(xué)生,通過知識蒸餾來減輕個(gè)性化過程中過擬合的影響。Li 等[16]提出了FedMD 算法,這是一個(gè)基于知識蒸餾和遷移學(xué)習(xí)的聯(lián)邦學(xué)習(xí)框架,它允許客戶端使用本地私有數(shù)據(jù)集和全球公共數(shù)據(jù)集單獨(dú)訓(xùn)練個(gè)性化模型。在多任務(wù)學(xué)習(xí)中,多個(gè)相關(guān)任務(wù)被同時(shí)解決,模型通過利用各個(gè)任務(wù)的共性和差異性來達(dá)到更好的訓(xùn)練效果。Smith 等[13]表明多任務(wù)學(xué)習(xí)是構(gòu)建個(gè)性化聯(lián)邦模型的一種合理方式,提出了一個(gè)聯(lián)邦多任務(wù)學(xué)習(xí)框架來應(yīng)對聯(lián)邦學(xué)習(xí)中與通信、掉隊(duì)和容錯(cuò)相關(guān)的挑戰(zhàn)。
聯(lián)邦學(xué)習(xí)中的另一個(gè)核心挑戰(zhàn)是高昂的通信成本?,F(xiàn)實(shí)世界中終端設(shè)備數(shù)量龐大,通信環(huán)境復(fù)雜等因素導(dǎo)致了聯(lián)邦學(xué)習(xí)的通信壓力非常大。為了減少在這種復(fù)雜環(huán)境下的通信量,面向通信成本優(yōu)化的聯(lián)邦學(xué)習(xí)算法主要從這兩方面進(jìn)行研究:減少客戶端和服務(wù)器的通信次數(shù);減小每次通信中傳輸?shù)臄?shù)據(jù)規(guī)模。例如,Yao 等[17]在客戶端的原有優(yōu)化目標(biāo)上增加了基于最大均值差異(Maximum Mean Discrepancy,MMD)距離的加權(quán)差異項(xiàng),通過MMD 距離來衡量全局模型和本地模型的差異,從而加速全局模型的收斂,減少訓(xùn)練過程中的通信次數(shù)。而Caldas 等[18]受到常被用于防止模型過擬合的隨機(jī)失活算法[19]的啟發(fā),提出了聯(lián)邦隨機(jī)失活算法。在每個(gè)全連接層上,該算法丟棄固定數(shù)量的全連接層參數(shù),但保證相鄰兩層失活后的輸出矩陣的維度仍然能夠進(jìn)行矩陣運(yùn)算;而在每個(gè)卷積層上,該算法通過丟棄固定數(shù)量的卷積核來減少參數(shù)。在傳統(tǒng)的隨機(jī)失活算法中,失活后的模型仍然具有和失活前模型一樣的大小,而在聯(lián)邦隨機(jī)失活算法中,因?yàn)橹粋鬏敿せ畹膮?shù),所以能顯著減少每輪的通信量。除了減少通信次數(shù)和減小通信數(shù)據(jù)規(guī)模這兩個(gè)方面,異步通信[20-21]和通信拓?fù)鋬?yōu)化[22]也是兩個(gè)重要的面向通信成本優(yōu)化的研究方向,但是由于實(shí)現(xiàn)技術(shù)難度大等原因,目前這兩個(gè)方向的相關(guān)研究還較少。
自動(dòng)編碼器(AutoEncoder,AE)和生成對抗網(wǎng)絡(luò)(Generative Adversarial Net,GAN)[23]是兩種著名的生成模型。AE 和GAN 都有許多變體,如變分自動(dòng)編碼器(Variational AutoEncoder,VAE)[24]、WGAN(Wasserstein Generative Adversarial Network)[25]、條件生成對抗網(wǎng)絡(luò)(Conditional Generative Adversarial Net,CGAN)[26]等。VAE由編碼器和解碼器組成,編碼器將數(shù)據(jù)樣本x編碼為隱層表示z,解碼器將隱層表示z解碼回?cái)?shù)據(jù)空間,這兩個(gè)過程可以分別表示為:
VAE 的訓(xùn)練目標(biāo)是使重建誤差盡可能小,即使x和盡可能接近。VAE 損失函數(shù)如下所示:
其中DKL指的是KL 散度。
GAN 同樣也包含編碼器和解碼器,通常被稱為生成器網(wǎng)絡(luò)G(z)的解碼器將隱層表示z映射到數(shù)據(jù)空間,同時(shí)通常被稱為判別器網(wǎng)絡(luò)D(x)的編碼器將訓(xùn)練一個(gè)代表數(shù)據(jù)真實(shí)性的概率y=D(x) ∈[0,1],其中:y越接近1,代表x是真實(shí)數(shù)據(jù)的概率越大;y越接近0,則代表x來自生成器網(wǎng)絡(luò)G(z)的概率越大。
生成器網(wǎng)絡(luò)和判別器網(wǎng)絡(luò)被同時(shí)訓(xùn)練:更新G的網(wǎng)絡(luò)參數(shù)來最小化ln( 1-D(G(z))),更新D的網(wǎng)絡(luò)參數(shù)來最小化ln(D(x)),二者進(jìn)行著一種兩方最大最小博弈(Two-player min-max game),其值函數(shù)為:
遷移學(xué)習(xí)可以解決機(jī)器學(xué)習(xí)中訓(xùn)練數(shù)據(jù)不足的問題,它試圖通過放寬訓(xùn)練數(shù)據(jù)和測試數(shù)據(jù)必須是獨(dú)立同分布的假設(shè),將知識從源域轉(zhuǎn)移到目標(biāo)域。遷移學(xué)習(xí)的基本方法可以分為四類:基于樣本的遷移、基于模型的遷移、基于特征的遷移和基于關(guān)系的遷移?;谔卣骱突谀P偷倪w移通常表現(xiàn)得更好,這也是目前大多數(shù)遷移學(xué)習(xí)工作的研究熱點(diǎn)。由于深度學(xué)習(xí)在許多研究領(lǐng)域取得了主導(dǎo)地位,研究如何通過深度神經(jīng)網(wǎng)絡(luò)有效地轉(zhuǎn)移知識也變得至關(guān)重要,這類方法被稱為深度遷移學(xué)習(xí)。深度遷移學(xué)習(xí)可以分為四類:基于實(shí)例的深度遷移學(xué)習(xí)、基于映射的深度遷移學(xué)習(xí)、基于網(wǎng)絡(luò)的深度遷移學(xué)習(xí)和基于對抗的深度遷移學(xué)習(xí)。基于網(wǎng)絡(luò)的深度遷移學(xué)習(xí)主要通過重復(fù)使用在源域中預(yù)訓(xùn)練的部分網(wǎng)絡(luò)來遷移知識,微調(diào)[27]就是一種基于網(wǎng)絡(luò)的深度遷移學(xué)習(xí)方法,它的主要思想如下:
深度神經(jīng)網(wǎng)絡(luò)的淺層通常學(xué)習(xí)數(shù)據(jù)的一般性特征,但隨著網(wǎng)絡(luò)的深入,深層網(wǎng)絡(luò)更注重學(xué)習(xí)特定性特征。因此,當(dāng)有一個(gè)完成訓(xùn)練的模型時(shí),可以通過凍結(jié)淺層網(wǎng)絡(luò)的參數(shù)并更新深層網(wǎng)絡(luò)的參數(shù),將該模型快速應(yīng)用于新的數(shù)據(jù)集或者訓(xùn)練任務(wù)。
類似地,在聯(lián)邦學(xué)習(xí)中,全局模型學(xué)習(xí)一般性特征,而局部模型學(xué)習(xí)特定性特征。可以使用全局模型學(xué)習(xí)一般性特征,通過微調(diào)在每個(gè)客戶端上學(xué)習(xí)特定性特征,從而快速訓(xùn)練本地模型。
在介紹本文提出的單輪通信聯(lián)邦學(xué)習(xí)算法FedGT 之前,先通過介紹FedAvg 算法來了解聯(lián)邦學(xué)習(xí)算法的基本流程。
FedAvg 算法的參與者有1 個(gè)中心服務(wù)器和N個(gè)客戶端,首先中心服務(wù)器要初始化模型參數(shù)w0,然后進(jìn)行T輪迭代:服務(wù)器將模型參數(shù)發(fā)送給隨機(jī)選出的K個(gè)客戶端,客戶端接收到模型參數(shù)對其更新后發(fā)回服務(wù)器,最后服務(wù)器聚合各個(gè)客戶端的模型參數(shù)得到新的模型參數(shù)。
算法1 FedAvg 算法。
不同于FedAvg 算法的多輪通信,F(xiàn)edGT 算法僅在客戶端和服務(wù)器之間進(jìn)行一輪通信。FedGT 算法主要包括三個(gè)步驟:各個(gè)客戶端利用本地?cái)?shù)據(jù)訓(xùn)練一個(gè)用于生成數(shù)據(jù)樣本的生成模型和一個(gè)用于推斷標(biāo)簽的局部預(yù)測模型,然后將這兩個(gè)模型的參數(shù)發(fā)送給服務(wù)器;服務(wù)器利用各個(gè)客戶端的生成模型生成數(shù)據(jù)樣本,然后再用客戶端的預(yù)測模型給這些樣本打標(biāo)簽,從而得到一個(gè)模擬數(shù)據(jù)集,服務(wù)器再利用該模擬數(shù)據(jù)集訓(xùn)練一個(gè)全局的預(yù)測模型并發(fā)送給客戶端;各個(gè)客戶端收到全局預(yù)測模型后再次利用全局預(yù)測模型和本地真實(shí)數(shù)據(jù)訓(xùn)練出個(gè)性化本地預(yù)測模型。FedGT算法流程如圖1所示。
圖1 FedGT算法流程Fig.1 Flowchart of FedGT algorithm
FedGT 算法會在服務(wù)器生成一個(gè)模擬數(shù)據(jù)集,并通過該模擬數(shù)據(jù)集訓(xùn)練出能代表數(shù)據(jù)一般性特征的全局預(yù)測模型。據(jù)我們所知,生成模擬數(shù)據(jù)的方式通常分為兩類:一類為不考慮數(shù)據(jù)各個(gè)維度相關(guān)性的分布擬合法;一類為考慮數(shù)據(jù)各個(gè)維度相關(guān)性的神經(jīng)網(wǎng)絡(luò)生成模型。
分布擬合法可以分為兩個(gè)步驟:選擇不同的數(shù)據(jù)分布,使用統(tǒng)計(jì)方法估計(jì)這些分布的參數(shù)值;確定哪個(gè)分布更加符合數(shù)據(jù)樣本,或者說,選出p值最大的分布作為最終分布。具體地,分布一般由四個(gè)參數(shù)定義:位置、規(guī)模、形狀和閾值。這些參數(shù)定義了不同的分布:位置參數(shù)規(guī)定了分布在X軸上的位置;規(guī)模參數(shù)決定了分布中的擴(kuò)散程度;形狀參數(shù)使分布具有不同的形狀;閾值參數(shù)則定義了分布在X軸上的最小值。分布的參數(shù)可以用各種統(tǒng)計(jì)方法來估計(jì)。例如最大似然估計(jì)法通過最小化負(fù)對數(shù)似然函數(shù)值來求得對分布參數(shù)的估計(jì)值。然后可以采用例如Kolmogorov-Smirnov 檢驗(yàn)的方式計(jì)算出各個(gè)分布在該數(shù)據(jù)上的p值,選擇出p值最大的分布即為最終確定的該數(shù)據(jù)的分布。
分布擬合法在采用不同分布擬合數(shù)據(jù)樣本時(shí),是對數(shù)據(jù)的每一個(gè)維度進(jìn)行單獨(dú)擬合,所以它并沒有考慮數(shù)據(jù)各維度的相關(guān)性;而現(xiàn)實(shí)世界的數(shù)據(jù)樣本各個(gè)維度存在極大的相關(guān)性,例如圖像數(shù)據(jù)的鄰近像素點(diǎn),文本數(shù)據(jù)的上下文都說明數(shù)據(jù)各維度存在極大相關(guān)性?;谏疃葘W(xué)習(xí)的一些生成模型考慮了數(shù)據(jù)各維度的相關(guān)性,例如典型的全連接層輸出特征各個(gè)維度在計(jì)算時(shí)都將輸入特征的各個(gè)維度乘以權(quán)重參數(shù)。
這些生成模型通常由編碼器和解碼器兩個(gè)部分組成,模型訓(xùn)練的目標(biāo)為最小化解碼器的重建誤差。在數(shù)據(jù)生成階段,只需將噪聲輸入到解碼器,解碼器即可輸出生成的模擬數(shù)據(jù)。VAE 和GAN 主要在編碼器和解碼器的設(shè)計(jì)上具有明顯差別:VAE 模型參數(shù)少,模型結(jié)構(gòu)簡單,易于調(diào)試,但是在復(fù)雜數(shù)據(jù)上的表現(xiàn)不佳;而GAN 通??梢杂糜诟訌?fù)雜的數(shù)據(jù)的生成,但其模型訓(xùn)練過程參數(shù)的調(diào)試比較具有挑戰(zhàn)性。所以FedGT 算法使用VAE 在簡單的數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),而在較復(fù)雜的數(shù)據(jù)集上則采用GAN 進(jìn)行實(shí)驗(yàn)。
建立生成模型可以得到模擬數(shù)據(jù)樣本X,然而僅通過沒有標(biāo)簽的數(shù)據(jù)樣本無法建立預(yù)測模型,所以客戶端需要建立預(yù)測模型發(fā)送到服務(wù)器,從而得到服務(wù)器生成的模擬數(shù)據(jù)的標(biāo)簽Y。本文選取了兩個(gè)圖像分類任務(wù)的數(shù)據(jù)集分別進(jìn)行實(shí)驗(yàn),客戶端的預(yù)測模型分別為簡單卷積網(wǎng)絡(luò)(Simple Convolutional Neural Network,Simple-CNN)和修改后的ResNet-18[28]。
對于一個(gè)客戶端,服務(wù)器會接收到的內(nèi)容包括:生成模型的解碼器Deci(z),本地預(yù)測模型,數(shù)據(jù)數(shù)量numi。生成模型解碼器Deci(z)輸入噪聲后可以得到模擬數(shù)據(jù)樣本,生成的數(shù)量為numi,這些模擬數(shù)據(jù)會被輸入客戶端的本地預(yù)測模型得到對應(yīng)的預(yù)測值。于是服務(wù)器在接受到來自一個(gè)客戶端的內(nèi)容后,可以產(chǎn)生一個(gè)樣本數(shù)量為numi的子數(shù)據(jù)集。當(dāng)全部N個(gè)客戶端都發(fā)送內(nèi)容到服務(wù)器后,服務(wù)器得到N個(gè)子數(shù)據(jù)集后,將這些數(shù)據(jù)集合并為,使用該數(shù)據(jù)集訓(xùn)練出全局預(yù)測模型Pglobal并將其發(fā)送給各個(gè)客戶端。這些步驟在算法2 中進(jìn)一步介紹。
FedAvg 和FedProx 等算法在得到全局預(yù)測模型后算法就結(jié)束了,所以每個(gè)客戶端最終得到的模型是相同的全局模型;但是由于每個(gè)客戶端上的數(shù)據(jù)通常是non-IID 的,同樣的模型在某些客戶端上表現(xiàn)良好,但在另一些客戶端上會表現(xiàn)得很糟糕。對于這個(gè)問題,一個(gè)更好的解決方案是對于每個(gè)客戶端單獨(dú)訓(xùn)練個(gè)性化模型。在FedGT 算法中,客戶端在接收到全局預(yù)測模型后,會利用本地的數(shù)據(jù)通過微調(diào)來得到個(gè)性化模型。具體地,客戶端會凍結(jié)網(wǎng)絡(luò)模型的淺層網(wǎng)絡(luò)參數(shù),利用本地?cái)?shù)據(jù)對深層網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)整。
算法2 FedGT 算法。
本文將FedGT 與FedAvg、FedProx 和集中式學(xué)習(xí)進(jìn)行了比較。集中式學(xué)習(xí)是指所有客戶端將它們的數(shù)據(jù)傳輸?shù)揭粋€(gè)中央服務(wù)器,然后服務(wù)器使用這些數(shù)據(jù)建立一個(gè)全局模型,并將全局模型參數(shù)發(fā)回給客戶端,這種方法的一個(gè)重要問題是傳輸客戶端真實(shí)數(shù)據(jù)的同時(shí)伴隨著數(shù)據(jù)泄露的風(fēng)險(xiǎn)。
為了說明算法的通用性,分別在CIFAR-10 和MNIST 數(shù)據(jù)集上采用了不同的生成模型和預(yù)測模型進(jìn)行實(shí)驗(yàn)。MNIST 數(shù)據(jù)集包含了7 × 104張手寫數(shù)字的灰度圖像,所有圖像被分為10 類,分別為手寫數(shù)字0~9,并且每張圖像已經(jīng)被標(biāo)準(zhǔn)化為28 × 28,訓(xùn)練集包含6 × 104張圖像,測試集包含1 × 104張圖像。CIFAR-10 數(shù)據(jù)集由6 × 104張32 × 32 的彩色圖像組成,所有圖像被分為10 類,分別為飛機(jī)、汽車、鳥、貓、鹿、狗、青蛙、馬、船和卡車,每類有6 × 103張圖像,包括5 × 103張訓(xùn)練圖像和1 × 103張測試圖像,數(shù)據(jù)集總共有5 × 104張訓(xùn)練圖像和1 × 104張測試圖像。
在MNIST 數(shù)據(jù)集上使用VAE 作為生成模型,使用Simple-CNN 作為預(yù)測模型。Simple-CNN 的結(jié)構(gòu)如表1 所示,除了表1 中提到的卷積層和池化層,使用ReLU 函數(shù)作為網(wǎng)絡(luò)的激活函數(shù),并增加了隨機(jī)失活層防止過擬合。
表1 Simple-CNN的結(jié)構(gòu)Tab.1 Structure of Simple-CNN
對于數(shù)據(jù)集在客戶端上的劃分,采用了IID 和non-IID 兩種方式。首先將訓(xùn)練集和測試集合并得到7 × 104張圖像,然后設(shè)置了20 個(gè)客戶端,IID 的數(shù)據(jù)劃分方式為:每個(gè)客戶端互不重復(fù)地隨機(jī)選取3.5 × 103張圖像作為本地?cái)?shù)據(jù)集,其中6/7 被作為本地訓(xùn)練集,1/7 被作為本地測試集。而non-IID 的劃分方式為:將所有數(shù)據(jù)切分成40 份,其中每份數(shù)據(jù)只包含10 個(gè)類別中的1 個(gè)類別,每個(gè)客戶端互不重復(fù)地隨機(jī)選取2 份數(shù)據(jù)作為本地?cái)?shù)據(jù)集,這保證了每個(gè)客戶端上至多有2 種類別的數(shù)據(jù),同樣地,non-IID 劃分的訓(xùn)練集-測試集比例為6∶1。
在FedGT 算法的步驟3 中,客戶端接收到全局預(yù)測模型進(jìn)行微調(diào)時(shí),凍結(jié)Conv1、Conv2 和FC1 這三層的參數(shù),只更新FC2 這一全連接層的參數(shù)。使用Adam 作為優(yōu)化器,使用交叉熵作為損失函數(shù),每批次數(shù)據(jù)量為64,訓(xùn)練輪數(shù)為500,學(xué)習(xí)率為5 × 10-4。
實(shí)驗(yàn)選取了集中式學(xué)習(xí),F(xiàn)edAvg、FedProx、FedMD 和FedDyn[29]作為基準(zhǔn)算法與FedGT 算法進(jìn)行比較。其中FedAvg 和FedProx 作為聯(lián)邦學(xué)習(xí)中較早發(fā)表的經(jīng)典算法,常被用于各類聯(lián)邦學(xué)習(xí)算法實(shí)驗(yàn)的基準(zhǔn)算法。FedDyn 是目前最新的聯(lián)邦學(xué)習(xí)算法,它在客戶端本地模型更新的損失函數(shù)中加入了一個(gè)動(dòng)態(tài)的正則器,使客戶端每輪的損失函數(shù)動(dòng)態(tài)更新,同時(shí)使全局經(jīng)驗(yàn)損失和局部經(jīng)驗(yàn)損失的最小值保持漸進(jìn)一致,F(xiàn)edDyn 算法的通信效率和準(zhǔn)確率均優(yōu)于FedAvg 和FedPox。FedMD 是個(gè)性化聯(lián)邦學(xué)習(xí)中的一個(gè)重要算法,它將知識蒸餾和遷移學(xué)習(xí)同時(shí)運(yùn)用在聯(lián)邦學(xué)習(xí)訓(xùn)練過程中,允許客戶端使用本地私有數(shù)據(jù)集和全球公共數(shù)據(jù)集單獨(dú)訓(xùn)練個(gè)性化模型。
基準(zhǔn)算法在訓(xùn)練過程中使用隨機(jī)梯度下降(Stochastic Gradient Descent,SGD)作為優(yōu)化器,使用交叉熵作為損失函數(shù),每批次數(shù)據(jù)量為64,訓(xùn)練輪次為200,初始學(xué)習(xí)率為0.1,每50 輪衰減80%。
表2 列出了所選基準(zhǔn)算法和FedGT 算法在MNIST 數(shù)據(jù)集上的準(zhǔn)確率,由于集中式學(xué)習(xí)是將數(shù)據(jù)集中到服務(wù)器上進(jìn)行訓(xùn)練,所以它在non-IID 劃分和IID劃分下的準(zhǔn)確率是相同的。FedGT 算法在IID 和non-IID 數(shù)據(jù)劃分下的準(zhǔn)確率超過了實(shí)驗(yàn)對比的所有基準(zhǔn)算法,然而不同于基準(zhǔn)算法,F(xiàn)edGT 算法在non-IID數(shù)據(jù)劃分下的準(zhǔn)確率比IID數(shù)據(jù)劃分的準(zhǔn)確率高。
表2 不同算法在MNIST數(shù)據(jù)集上的準(zhǔn)確率 單位:%Tab.2 Accuracies of different algorithms on MNIST dataset unit:%
在CIFAR-10 數(shù)據(jù)集上使用GAN 作為生成模型,采用修改后的ResNet-18 作為預(yù)測模型。修改后的ResNet-18 殘差單元的結(jié)構(gòu)如圖2 所示,圖中Conv1、Conv2 和Conv3 即為表3中的Conv1、Conv2 和Conv3。
表3 修改后的ResNet-18網(wǎng)絡(luò)結(jié)構(gòu)Tab.3 Modified ResNet-18 network structure
圖2 修改后的ResNet-18殘差單元結(jié)構(gòu)Fig.2 Structure of residual unit of modified ResNet-18
由于ResNet-18 網(wǎng)絡(luò)輸入圖像的尺寸為3 × 224 × 224,而CIFAR10 的圖像尺寸為3 × 32 × 32,所以對ResNet-18 進(jìn)行了一定調(diào)整,主要改動(dòng)為:將第一個(gè)卷積核為7 × 7 的卷積層以及一個(gè)最大池化層替換為一個(gè)卷積核為7 × 7 的卷積層,以此來適應(yīng)CIFAR-10 的圖像尺寸,具體網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)如表3 所示。其中:Conv2d(3,1)代表卷積核為3 × 3、步長為1的二維卷積層;ResUnit(n,m,k1,k2,k3)代表一個(gè)如圖2 所示的殘差單元,Conv1、Conv2 和Conv3 的卷積核都為n×n,輸出通道都為m,步長分別為k1、k2 和k3;AvgPool(4,4)代表核為4 × 4、步長為4 的平均池化層;Linear(512,10)代表輸入為512 維向量、輸出為10 維向量的全連接層。
對于數(shù)據(jù)集在客戶端上的劃分,采用和在MNIST 數(shù)據(jù)集上相同的IID 劃分和non-IID 劃分策略,首先合并訓(xùn)練集和測試集,客戶端數(shù)量為20,如果為IID 劃分則隨機(jī)選取,如果為non-IID 劃分則使每個(gè)客戶端至多有兩個(gè)類別的樣本,客戶端上訓(xùn)練集和測試集的比例保持和原數(shù)據(jù)集一致,都為5∶1。
客戶端接收到全局預(yù)測模型進(jìn)行微調(diào)時(shí),凍結(jié)ConvIn、Layer1、Layer2 和Layer3 層的參數(shù),只更新Layer4 和Linear。在FedGT 和基準(zhǔn)算法上使用的實(shí)驗(yàn)參數(shù)與在MNIST 上的實(shí)驗(yàn)參數(shù)一致,表4 列出了實(shí)驗(yàn)結(jié)果?;鶞?zhǔn)算法在CIFAR-10上的準(zhǔn)確率整體上低于MNIST 上的準(zhǔn)確率,F(xiàn)edGT 算法的準(zhǔn)確率在IID 和non-IID 數(shù)據(jù)劃分上都優(yōu)于基準(zhǔn)算法。
表4 算法在CIFAR-10數(shù)據(jù)集上的準(zhǔn)確率 單位:%Tab.4 Accuracies of algorithms on CIFAR-10 dataset unit:%
FedGT 算法的主要目的是減少聯(lián)邦學(xué)習(xí)中的通信輪數(shù)和通信數(shù)據(jù)量。通常一個(gè)聯(lián)邦學(xué)習(xí)模型需要進(jìn)行E輪通信,每輪通信均傳輸預(yù)測模型參數(shù);而FedGT 算法只需要一輪通信,客戶端發(fā)送給服務(wù)器生成模型解碼器和本地預(yù)測模型的參數(shù),服務(wù)器發(fā)送給客戶端全局預(yù)測模型參數(shù)。表5 中對比了FedGT 和基準(zhǔn)算法在兩個(gè)數(shù)據(jù)集上的通信量。由表5 可以看出:在MNIST 數(shù)據(jù)集上,F(xiàn)edGT 算法的通信量約為FedAvg、FedProx、FedDyn 算法通信量的1/10,約為FedMD 算法的1/100;在CIFAR-10 數(shù)據(jù)集上,F(xiàn)edGT 算法的通信量約為FedAvg、FedProx、FedDyn 算法的通信量的1/100,約為FedMD算法的1/10。
表5 MNIST和CIFAR-10數(shù)據(jù)集上的通信效率對比Tab.5 Comparison of communication efficiency on MNIST and CIFAR-10 datasets
聯(lián)邦學(xué)習(xí)通信成本的減少會帶來計(jì)算量的增加,本文以每秒浮點(diǎn)運(yùn)算次數(shù)(Floating-Point Operations Per Second,F(xiàn)LOPS)為單位計(jì)算了模型在前向和反向傳播時(shí)的計(jì)算量,并考慮了損失函數(shù)不同對反向傳播過程中計(jì)算量的影響和服務(wù)器端聚合模型的計(jì)算開銷,最終得到的計(jì)算量如表6 所示,其中客戶端這一欄為單個(gè)客戶端的計(jì)算開銷。計(jì)算結(jié)果顯示,F(xiàn)edGT 算法增加了聯(lián)邦學(xué)習(xí)在服務(wù)器上的計(jì)算開銷,但減少了客戶端上的計(jì)算開銷。聯(lián)邦學(xué)習(xí)的客戶端通常為移動(dòng)智能設(shè)備,只有當(dāng)設(shè)備空閑時(shí)才能進(jìn)行模型訓(xùn)練,而這些設(shè)備計(jì)算能力遠(yuǎn)不如計(jì)算能力強(qiáng)大的服務(wù)器,所以FedGT算法減少客戶端計(jì)算開銷、增加服務(wù)器計(jì)算開銷有利于聯(lián)邦學(xué)習(xí)的落地。
表6 MNIST和CIFAR-10數(shù)據(jù)集上的計(jì)算開銷對比 單位:FLOPsTab.6 Comparison of computing overhead on MNIST and CIFAR-10 datasets unit unit:FLOPs
本文提出了一種新的聯(lián)邦學(xué)習(xí)算法——FedGT 算法。該算法采用服務(wù)器生成模擬數(shù)據(jù)來訓(xùn)練全局模型,這一做法能夠?qū)⑼ㄐ诺妮啍?shù)減少至一輪,并且還在客戶端通過模型的微調(diào)實(shí)現(xiàn)模型個(gè)性化來解決客戶端異質(zhì)性問題。FedGT 算法在不同的數(shù)據(jù)集上使用不同的架構(gòu)進(jìn)行了實(shí)驗(yàn),結(jié)果表明它在數(shù)據(jù)以IID 和non-IID 方式分布時(shí)均優(yōu)于FedAvg、FedProx、FedDyn 和FedMD 算法。
FedGT 算法為聯(lián)邦學(xué)習(xí)的研究提出了一個(gè)新的方向:除了傳輸目標(biāo)訓(xùn)練模型的參數(shù),還可以傳輸其他信息(例如生成模型)來加快模型訓(xùn)練,減少通信輪數(shù),這是一個(gè)新穎且具有挑戰(zhàn)的研究方向。我們的下一步研究工作是在更廣泛的數(shù)據(jù)集上檢驗(yàn)FedGT 算法,這樣才能真正捕捉到聯(lián)邦學(xué)習(xí)真實(shí)場景下的大規(guī)模分布的復(fù)雜性。