联邦学习开山之作:Communication-Efficient Learning of Deep Networks from Decentralized Data 带你走进最初的联邦学习 论文精读
原文鏈接:Communication-Efficient Learning of Deep Networks from Decentralized Data (mlr.press)
該論文是最早提出聯邦學習的論文,作者結合背景提出了聯邦平均的算法,并作了相應驗證實驗。
ABS
隨著移動設備的用戶增加,產生了大量的分散數據。這些分散數據通常是涉及用戶隱私的,所以想要將數據集中起來進行訓練是不現實的,作者提出一種能夠在各移動設備上訓練一個聯合模型的方法——Federated Learning。該方法同時能夠提高通信效率,相比于同步的隨機梯度下降(synchronized stochastic gradient descent),通信耗時下降到原來的110\frac 1 {10}101?至1100\frac 1 {100}1001?。
1 INTRO
手機和平板用戶的增多以及用戶隨身攜帶意味著會產生大量的數據,這些數據有著巨大的價值,可以通過對這些數據進行訓練來提升用戶的體驗,但是這些數據是設計用戶隱私的,將數據直接收集起來是不合法的。
聯邦學習對模型訓練與整體數據的需求進行解耦(各方獨自訓練),這樣做一方面能夠保護用戶隱私,另一方面攻擊方想要獲得數據只能攻擊各個用戶,而不能直接攻擊中心服務器(中心服務器沒有用戶的數據)。
本文的貢獻:
- 指明了對于分散數據進行機器學習的方向(即提出了聯邦學習);
- 一種聯邦學習的方法(該方法是各個客戶端進行SGD和服務器進行模型平均的有機組合);
- 對提出的方法進行廣泛的具有實踐意義的評價。
聯邦學習的特點:
- 數據大都來自于真實數據訓練效果更加貼合實際;
- 數據高度敏感且相對于單個用戶來說,數據量非常大;
- 對于有監督學習,可以通過與用戶的互動輕松對數據進行標號。
聯邦學習對于隱私性的保護:
- 會進行通信的數據只有需要的更新,這保證了用戶數據的安全;
- 更新數據不需要保存,一旦更新成功,更新數據將被丟失;
- 通過更新數據對原始數據的破解幾乎不可能。
聯邦學習與分布式學習有著幾個顯著的不同:
- 數據分布非獨立同分布:不同的用戶有著不同的行為;
- 數據分布不平衡:指某些參與者的數據可能很多,而某些參與者數據可能很少;
- 大量的參與者:一個軟件的用戶可能非常多(例如某款輸入法);
- 受限的通信:參與者的信號可能非常差,甚至出現離線的情況;
聯邦學習需要處理以下問題:
- 各個參與方的數據可能會發生改變(例如刪除、添加、編輯照片);
- 參與方的數據分布非常復雜(不同的群體的手機使用情況差異可能會非常大)。
本論文的實驗環境是一個可控的環境,主要用于解決非獨立數據分布和不平衡數據分布的問題:
- 實驗中固定有KKK個參與者,參與者的數據集固定不發生更改;
- 每一輪開始時,選擇C?K(0≤C≤1)C*K(0\le C\le1)C?K(0≤C≤1)個參與者(實驗發現當參與者的數量超過某個值是效果會出現下降,所以只選擇部分的參與者),服務器將最新的參數下發給選中的參與者,參與者進行聯邦學習(包含訓練、聯邦聚合等);
通常情況下(對于數據集中分布),我們需要:
minw∈Rdf(w)=1n∑i=1nfi(w)\underset {w\in R^d}{min} f(w) = \frac 1 n\sum_{i=1}^nf_i(w) w∈Rdmin?f(w)=n1?i=1∑n?fi?(w)
上式中的fi(w)f_i(w)fi?(w)代表樣本iii的損失。
在聯邦學習中,我們需要做出一定的變形:
f(w)=∑k=1KnknFk(w)Fk(w)=1nk∑i∈Pkfi(w)\begin{aligned} &f(w) = \sum_{k=1}^K\frac {n_k} nF_k(w)\\ &F_k(w)=\frac 1 {n_k}\sum_{i \in P_k}f_i(w) \end{aligned} ?f(w)=k=1∑K?nnk??Fk?(w)Fk?(w)=nk?1?i∈Pk?∑?fi?(w)?
上式中的KKK代表參與更新的參與方個數,nnn代表總共數據個數,nkn_knk?代表第kkk個參與方擁有的數據個數,Fk(w)F_k(w)Fk?(w)代表第kkk個參與方的平均損失,fi(w)f_i(w)fi?(w)代表樣本iii的損失,PkP_kPk?代表第kkk個參與方的數據索引集合。
在滿足獨立同分布的數據中(傳統的分布式機器學習中,數據統一隨機的下放至每個參與方),E[Fk(w)]=f(w)E[F_k(w)]=f(w)E[Fk?(w)]=f(w)(也就是樣本平均的期望是與總體期望相同),但是在聯邦學習中數據往往不是獨立同分布的,這會導致,Fk(w)F_k(w)Fk?(w)并不能很好的對f(w)f(w)f(w)進行近似。
聯邦學習與傳統數據中心計算的不同:
- 傳統的數據中心計算,往往通信的消耗是相對較小的,計算的時間是相對較大的,聯邦學習正好相反(用戶可能只會在特定的時間才回進行上傳,例如睡覺時,而由于用戶的處理器往往不會太差,計算的耗費相對就會較小);
有兩種方法能夠緩解聯邦學習中通信耗時的問題(核心都是讓用戶進行大量的計算,這樣能減少通信時間的占比):
- 增加并行度,讓更多的參與者進行計算;
- 增加計算度,讓一個參與者進行更多的計算。
相關工作介紹:
- 傳統的分布式學習考慮的是平衡的分布(各個參與方的計算量與各自的計算能力相匹配)和獨立同分布(各方的數據是獨立同分布的);
- 傳統的分布式學習只會進行一次集中更新,已經被證明,這樣訓練出來的模型在最壞情況下可能會比在單個參與方訓練的模型效果差。
2 The Federated Averaging Algorithm
Federated SGD:當選擇C=1C=1C=1并且使用SGD進行聯邦學習,本文將這種基準定義為Federated SGD,一種簡單的實現如下:
- 選取C=1C=1C=1,即每次所有結點都參與計算;
- 對于參與方kkk,計算gk=?Fk(wt)g_k=\nabla F_k(w_t)gk?=?Fk?(wt?),并將計算結果發送至服務器;
- 服務器對各方的梯度進行聚合:wt+1=wt?η∑k=1Knkngkw_{t+1}=w_t-\eta \sum_{k=1}^K \frac {n_k}n g_kwt+1?=wt??η∑k=1K?nnk??gk?;
同時為了提升各個參與方的計算量,作者提出Federated Averaging:
- 對于參與方kkk,計算wt+1k=wtk?ηgkw_{t+1}^k = w^k_t-\eta g_kwt+1k?=wtk??ηgk?,同時將計算結果發送至服務器;
- 服務器對參數進行聚合:wt+1=∑k=1Knknwt+1kw_{t+1}=\sum_{k=1}^K\frac {n_k} n w_{t+1}^kwt+1?=∑k=1K?nnk??wt+1k?。
這種方法的好處在于:參與方在發送參數之前可以進行多次的參數計算,這也就增加了參與方的計算量,Federated SGD是一種同步算法(需要等待所有方計算完梯度),而Federated Averaging并不算完全的同步算法(可以根據情況調整每一方進行計算的次數,但是還是會有同步操作)。
Federated Averaging的偽代碼如下:
Federated Averaging中參與方kkk的參數更新次數為:μk=EnkB\mu_k=E\frac {n_k} Bμk?=EBnk??。
作者對于不同的聚和權重以及參數的初始化方法進行了實驗:
- 作者以兩個參與方為例,分別進行了幾種實驗;
- 第一種實驗對兩個參與方采用不同的隨機種子對參數初始化,第二種實驗對兩個參與方采用相同的隨機種子對參數進行初始化;
- 聚合采用wt+1=θwt1+(1?θ)wt2,θ∈[?0.2,1.2]w_{t+1} = \theta w_t^1 + (1-\theta) w_t^2,\ \theta\in[-0.2,1.2]wt+1?=θwt1?+(1?θ)wt2?,?θ∈[?0.2,1.2];
- 在MNIST上進行訓練。
實驗結果如下圖:
(左圖為不同的隨機種子進行初始化,右圖為相同的隨機種子進行初始化;橫軸代表θ\thetaθ的取值,縱軸代表損失,注意兩圖縱軸的損失的刻度不同)
實驗結果表明:使用相同的隨機種子進行初始化,同時使用最符合常識的聚合方法(θ=12\theta=\frac 1 2θ=21?)效果是最好的。
3 Experimental Results
3.1 Datasets
實驗數據集:
- MNIST:包含手寫的阿拉伯數字0-9圖片(每張圖片只有一個數字),目標是對輸入圖片中的數字進行識別;
- 從The complete Works of William Shakespeare構造(輸入是句子)一個數據集用于語言模型的訓練記作Play&Role,目標是對輸入進行下一個詞語的預測。
對于上述的兩組數據集會分別產生四類數據集應用于不同的實驗:
- MNIST + IDD + balanced:將MNIST隨機地劃分成100100100份(實驗中有100100100個參與者),每份有600600600個樣例;
- MNIST + Non-IDD + balanced:將MNIST的圖片按照包含的數字進行排序,然后將圖片等間距的分成100100100份,這意味著,每個參與者的數據集最多只有兩種數字;
- Play&Role + Non-IDD + unbalanced:記錄每一場臺詞超過兩句的人物數量(一共有114611461146個),人物數量即參與方的個數,每個參與方擁有的數據即某個人物的在某一場的臺詞(這意味著有的參與者的數據量非常大,有的非常小,這顯然是unbalanced,而數據同樣顯然是Non-IDD);
- Play&Role + IDD + balanced:將上述所有的臺詞組合起來,隨機的下發給114611461146個參與者。
3.2 Models
實驗用到的模型:
- MNIST 2NN:一個有著兩個隱藏層的多層感知機,用于MNIST數據集的訓練;
- CNN:一個使用卷積窗口大小為555的卷積神經網絡,用于MNIST數據集的訓練;
- LSTM:長短期記憶,用于Play&Role模型的訓練
3.3 Experiment 1
實驗一:增加并行度實驗,所謂增加并行度就是增加參與計算的結點數量,通過增加CCC實現,下表展示了在MNIST上不同的CCC對于收斂速度的影響:
(表數字代表到達測試精度99%99\%99%(CNN)或者97%97\%97%(2NN)所需要的輪數(輪數是通過線性插值的方法計算出來),每個表的第一行代表基準,C=0C=0C=0代表每輪只選擇一個結點進行計算,BBB代表每個結點的批量大小(如Algorithm 1描述),B=∞B=\inftyB=∞代表每個節點迭代時選取自己擁有的所有數據,EEE代表結點的Epoch,圖中的“—”代表在規定時間內沒有達到指定精度)
當參與計算的結點增大到一定值時,收斂速度有可能會有所下降,所以一般不會讓所有的結點都參與計算。
3.4 Experiment 2
實驗二:增加每個節點的計算量,通過改變μ\muμ(結點參數更新次數)來改變節點的計算量,而μ=EnkB\mu = E \frac {n_k} Bμ=EBnk??(見Algorithm 1)所以通過減少BBB或者增加EEE可以增加結點的計算量。實驗的結果如下圖所示(本次試驗中固定C=0.1C=0.1C=0.1):
可以看到當單個節點的計算量增加時,整體收斂所需要的輪數也減少了。
實驗二中出現了一個有趣的現象:在CNN中,當E=1,B=∞E=1,B=\inftyE=1,B=∞時,FedSGD算法在120012001200輪之后達到了99.22%99.22\%99.22%準確率,而FedAVG在B=10,E=20B=10,E=20B=10,E=20時,300輪就達到了99.44%99.44\%99.44%的準確度。這一點是不符合常識的,因為FedSGD效果應該等價于將數據全部匯總進行訓練,而作者給出了這種現象的解釋:FedAVG丟失整體性類似發揮了dropout的效果。
3.5 Experiment 3
實驗三:參與方極致利用自己數據,在該實驗中,參與方會盡可能進行計算(EEE增大),理論上如果EEE趨向于無窮大,那么初始化參數的影響就會被忽略(只要他們能夠到達同一個局部最優解),實驗結果如下圖所示:
從上圖中可以看到當EEE增大時,損失反而上升了,同時出現了巨大的抖動,作者給出的原因是:實驗中學習率是固定的,當EEE增大時,在后續進行訓練時,學習率會變得相對較大,所以出現了抖動。
3.6 Experiment 4
實驗四:CIFAR-10上的實驗。實驗結果如下圖所示:
3.7 Experiment 5
實驗五:大規模的長短期記憶實驗,即對Play&Role數據集進行訓練,訓練結果如下圖:
4 Conclusions and Future Work
結論:實驗結果表名聯邦學習可以應用到實際工程中。
未來工作:未來工作的方向將會包括安全多方計算,對隱私保護的更強的保證。
總結
以上是生活随笔為你收集整理的联邦学习开山之作:Communication-Efficient Learning of Deep Networks from Decentralized Data 带你走进最初的联邦学习 论文精读的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Spring Boot 实现接口的各种参
- 下一篇: 通讯录管理系统(C++)