联邦学习之自定义化

Posted by kalos Aner on December 4, 2024

联邦学习系列

联邦学习之基础用法

联邦学习之分布训练

联邦学习之自定义化

联邦学习之隐私增强

联邦学习之带宽需求

引言

本节介绍自定义联邦学习。当客户端很多时每一次聚合的收益会随着选择的客户端数量递减,所以不需要选择所有的客户端进行聚合,只需要选择客户端的一个子集。在拥有百万的客户端的场景下只需要选择数百个客户端,或者最多数千个。根据任务的不同,可以有不同的策略。常见的选择策略是随机选择。另外还可以进行顺序训练,每个客户端训练好模型发送给下一个客户端进行训练。除此之外还可以定义客户端需要知道的其他超参数,例如它应该训练多长时间。本小节所用到的资源都放在这里

数据准备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from utils3 import *
# flwr_datasets 提供了一个名为 FederatedDataset 的类,这个类将许多现有数据集(例如MNIST)进行了分区,允许用户为每个客户端生成小型的训练和测试集。load_data 以一个分区 ID 作为输入,该 ID 指定要加载的数据集分区。此处使用的数据集是 MNIST,被分成了 5 个分区,每个分区按照 80 比 20 的比例分为训练集和测试集,这是通过 train_test_split 实现的。
def load_data(partition_id):
    fds = FederatedDataset(dataset="mnist", partitioners={"train": 5})
    partition = fds.load_partition(partition_id)

    traintest = partition.train_test_split(test_size=0.2, seed=42)
    traintest = traintest.with_transform(normalize)
    trainset, testset = traintest["train"], traintest["test"]

    trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
    testloader = DataLoader(testset, batch_size=64)
    return trainloader, testloader

用户端和服务端配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# 接受服务端的轮数,返回一个字典数据, 表示客户端应该在本地训练的轮次数。客户端可以动态的改变本地轮次数,如下代码。
def fit_config(server_round: int):
    config_dict = {
        "local_epochs": 2 if server_round < 3 else 5,
    }
    return config_dict

net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))

def server_fn(context: Context):
    strategy = FedAvg(
        min_fit_clients=5,
        fraction_evaluate=0.0, # 设置为 0,表示不进行客户端评估。
        initial_parameters=params,
        on_fit_config_fn=fit_config,  # <- NEW 把 fit_config 赋值给 on_fit_config_fn
    )
    config=ServerConfig(num_rounds=3)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

server = ServerApp(server_fn=server_fn)

class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        set_weights(self.net, parameters)

        epochs = config["local_epochs"]
        log(INFO, f"client trains for {epochs} epochs")
        train_model(self.net, self.trainloader, epochs)

        return get_weights(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_weights(self.net, parameters)
        loss, accuracy = evaluate_model(self.net, self.testloader)
        return loss, len(self.testloader), {"accuracy": accuracy}
    
def client_fn(context: Context) -> Client:
    net = SimpleModel()
    partition_id = int(context.node_config["partition-id"])
    trainloader, testloader = load_data(partition_id=partition_id)
    return FlowerClient(net, trainloader, testloader).to_client()


client = ClientApp(client_fn)

进行训练

1
2
3
4
5
run_simulation(server_app=server,
               client_app=client,
               num_supernodes=5,
               backend_config=backend_setup
               )

注:训练的时候一定要给够内存和可用的网络,不然有可能会失败。