联邦学习:按Dirichlet分布划分Non-IID样本
Python微信訂餐小程序課程視頻
https://edu.csdn.net/course/detail/36074
Python實戰量化交易理財系統
https://edu.csdn.net/course/detail/35475
我們在《Python中的隨機采樣和概率分布(二)》介紹了如何用Python現有的庫對一個概率分布進行采樣,其中的Dirichlet分布大家一定不會感到陌生。該分布的概率密度函數為
P(x;α)∝k∏i=1xαi?1ix=(x1,x2,…,xk),xi>0,k∑i=1xi=1α=(α1,α2,…,αk).αi>0P(\bm{x}; \bm{\alpha}) \propto \prod_{i=1}^{k} x_{i}^{\alpha_{i}-1} \
\bm{x}=(x_1,x_2,…,x_k),\quad x_i > 0 , \quad \sum_{i=1}^k x_i = 1\
\bm{\alpha} = (\alpha_1,\alpha_2,…, \alpha_k). \quad \alpha_i > 0
其中α\bm{\alpha}為參數。
我們在聯邦學習中,經常會假設不同client間的數據集不滿足獨立同分布(Non-IID)。那么我們如何將一個現有的數據集按照Non-IID劃分呢?我們知道帶標簽樣本的生成分布看可以表示為p(x,y)p(\bm{x}, y),我們進一步將其寫作p(x,y)=p(x|y)p(y)p(\bm{x}, y)=p(\bm{x}|y)p(y)。其中如果要估計p(x|y)p(\bm{x}|y)的計算開銷非常大,但估計p(y)p(y)的計算開銷就很小。所有我們按照樣本的標簽分布來對樣本進行Non-IID劃分是一個非常高效、簡便的做法。
總而言之,我們采取的算法思路是盡量讓每個client上的樣本標簽分布不同。我們設有KK個類別標簽,NN個client,每個類別標簽的樣本需要按照不同的比例劃分在不同的client上。我們設矩陣X∈RK?N\bm{X}\in \mathbb{R}^{K*N}為類別標簽分布矩陣,其行向量xk∈RN\bm{x}_k\in \mathbb{R}^N表示類別kk在不同client上的概率分布向量(每一維表示kk類別的樣本劃分到不同client上的比例),該隨機向量就采樣自Dirichlet分布。
據此,我們可以寫出以下的劃分算法:
import numpy as np
np.random.seed(42)
def split\_noniid(train\_labels, alpha, n\_clients):'''參數為alpha的Dirichlet分布將數據索引劃分為n\_clients個子集'''n_classes = train_labels.max()+1label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)# (K, N)的類別標簽分布矩陣X,記錄每個client占有每個類別的多少class_idcs = [np.argwhere(train_labels==y).flatten() for y in range(n_classes)]# 記錄每個K個類別對應的樣本下標client_idcs = [[] for _ in range(n_clients)]# 記錄N個client分別對應樣本集合的索引for c, fracs in zip(class_idcs, label_distribution):# np.split按照比例將類別為k的樣本劃分為了N個子集# for i, idcs 為遍歷第i個client對應樣本集合的索引for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):client_idcs[i] += [idcs]client_idcs = [np.concatenate(idcs) for idcs in client_idcs]return client_idcs
加下來我們在EMNIST數據集上調用該函數進行測試,并進行可視化呈現。我們設client數量N=10N=10,Dirichlet概率分布的參數向量α\bm{\alpha}滿足αi=1.0,?i=1,2,…N\alpha_i=1.0,\space i=1,2,…N:
import torch
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plttorch.manual_seed(42)if __name__ == "\_\_main\_\_":N_CLIENTS = 10 DIRICHLET_ALPHA = 1.0train_data = datasets.EMNIST(root=".", split="byclass", download=True, train=True)test_data = datasets.EMNIST(root=".", split="byclass", download=True, train=False)n_channels = 1input_sz, num_cls = train_data.data[0].shape[0], len(train_data.classes)train_labels = np.array(train_data.targets)# 我們讓每個client不同label的樣本數量不同,以此做到Non-IID劃分client_idcs = split_noniid(train_labels, alpha=DIRICHLET_ALPHA, n_clients=N_CLIENTS)# 展示不同client的不同label的數據分布plt.figure(figsize=(20,3))plt.hist([train_labels[idc]for idc in client_idcs], stacked=True, bins=np.arange(min(train_labels)-0.5, max(train_labels) + 1.5, 1),label=["Client {}".format(i) for i in range(N_CLIENTS)], rwidth=0.5)plt.xticks(np.arange(num_cls), train_data.classes)plt.legend()plt.show()
最終的可視化結果如下:
可以看到,62個類別標簽在不同client上的分布確實不同,證明我們的樣本劃分算法是有效的。
總結
以上是生活随笔為你收集整理的联邦学习:按Dirichlet分布划分Non-IID样本的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 基环树
- 下一篇: P5049 [NOIP2018 提高组]