联邦学习系列
引言
本节介绍自定义联邦学习。当客户端很多时每一次聚合的收益会随着选择的客户端数量递减,所以不需要选择所有的客户端进行聚合,只需要选择客户端的一个子集。在拥有百万的客户端的场景下只需要选择数百个客户端,或者最多数千个。根据任务的不同,可以有不同的策略。常见的选择策略是随机选择。另外还可以进行顺序训练,每个客户端训练好模型发送给下一个客户端进行训练。除此之外还可以定义客户端需要知道的其他超参数,例如它应该训练多长时间。本小节所用到的资源都放在这里
数据准备
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from utils3 import *
# flwr_datasets 提供了一个名为 FederatedDataset 的类,这个类将许多现有数据集(例如MNIST)进行了分区,允许用户为每个客户端生成小型的训练和测试集。load_data 以一个分区 ID 作为输入,该 ID 指定要加载的数据集分区。此处使用的数据集是 MNIST,被分成了 5 个分区,每个分区按照 80 比 20 的比例分为训练集和测试集,这是通过 train_test_split 实现的。
def load_data(partition_id):
fds = FederatedDataset(dataset="mnist", partitioners={"train": 5})
partition = fds.load_partition(partition_id)
traintest = partition.train_test_split(test_size=0.2, seed=42)
traintest = traintest.with_transform(normalize)
trainset, testset = traintest["train"], traintest["test"]
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64)
return trainloader, testloader
用户端和服务端配置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# 接受服务端的轮数,返回一个字典数据, 表示客户端应该在本地训练的轮次数。客户端可以动态的改变本地轮次数,如下代码。
def fit_config(server_round: int):
config_dict = {
"local_epochs": 2 if server_round < 3 else 5,
}
return config_dict
net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))
def server_fn(context: Context):
strategy = FedAvg(
min_fit_clients=5,
fraction_evaluate=0.0, # 设置为 0,表示不进行客户端评估。
initial_parameters=params,
on_fit_config_fn=fit_config, # <- NEW 把 fit_config 赋值给 on_fit_config_fn
)
config=ServerConfig(num_rounds=3)
return ServerAppComponents(
strategy=strategy,
config=config,
)
server = ServerApp(server_fn=server_fn)
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, testloader):
self.net = net
self.trainloader = trainloader
self.testloader = testloader
def fit(self, parameters, config):
set_weights(self.net, parameters)
epochs = config["local_epochs"]
log(INFO, f"client trains for {epochs} epochs")
train_model(self.net, self.trainloader, epochs)
return get_weights(self.net), len(self.trainloader), {}
def evaluate(self, parameters, config):
set_weights(self.net, parameters)
loss, accuracy = evaluate_model(self.net, self.testloader)
return loss, len(self.testloader), {"accuracy": accuracy}
def client_fn(context: Context) -> Client:
net = SimpleModel()
partition_id = int(context.node_config["partition-id"])
trainloader, testloader = load_data(partition_id=partition_id)
return FlowerClient(net, trainloader, testloader).to_client()
client = ClientApp(client_fn)
进行训练
1
2
3
4
5
run_simulation(server_app=server,
client_app=client,
num_supernodes=5,
backend_config=backend_setup
)
注:训练的时候一定要给够内存和可用的网络,不然有可能会失败。