联邦学习系列
引言
本小节使用 Flower 和 PyTorch 运行联邦学习项目,该项目使用到的模型和数据只是一个示例但是可以扩展到大多数模型和数据集像:TensorFlow、JAX、Hugging Phase Transformers等。本小节所用到的资源都放在这里。
在一个基本的联邦学习系统中,你有一个服务器和多个客户端。服务器本身通常没有任何数据,它可以用一些用于评估全局模型的数据,但是在普通联邦学习中它没有任何数据,客户端拥有实际训练数据。服务端和客户端都拥有一个模型副本,服务端的模型被称为全局模型,客户端的模型被称为局部模型。
开始时,服务端初始化全局模型参数并发送给客户端,客户端在本地数据集上训练模型,但是并不训练到收敛,而是每训练一个周期都发送自身的模型给服务端,服务端对所有模型进行聚合。最常见的模型聚合算法是联邦平均算法(根据每个特定客户端上进行训练的训练示例数量进行加权平均)。联邦学习是一个迭代的过程,他会一直重复上述操作直到收敛或者达到指定周期。
数据处理
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 导入必要的包
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import ndarrays_to_parameters, Context
from flwr.server import ServerApp, ServerConfig
from flwr.server import ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from utils2 import *
# 预处理训练数据
trainset = datasets.MNIST(
"./MNIST_data/", download=True, train=True, transform=transform
)
total_length = len(trainset)
split_size = total_length // 3
torch.manual_seed(42)
part1, part2, part3 = random_split(trainset, [split_size] * 3)
part1 = exclude_digits(part1, excluded_digits=[1, 3, 7])
part2 = exclude_digits(part2, excluded_digits=[2, 5, 8])
part3 = exclude_digits(part3, excluded_digits=[4, 6, 9])
train_sets = [part1, part2, part3]
# 预处理测试数据
testset = datasets.MNIST(
"./MNIST_data/", download=True, train=False, transform=transform
)
print("Number of examples in `testset`:", len(testset))
testset_137 = include_digits(testset, [1, 3, 7])
testset_258 = include_digits(testset, [2, 5, 8])
testset_469 = include_digits(testset, [4, 6, 9])
用户端和服务端配置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Sets the parameters of the model
def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict(
{k: torch.tensor(v) for k, v in params_dict}
)
# load_state_dict 是 torch 中的方法,作用是将预训练的参数权重加载到新的模型之中
net.load_state_dict(state_dict, strict=True)
# Retrieves the parameters from the model
def get_weights(net):
ndarrays = [
val.cpu().numpy() for _, val in net.state_dict().items()
]
return ndarrays
# 定义一个 Flower 的客户端类
class FlowerClient(NumPyClient):
def __init__(self, net, trainset, testset):
self.net = net
self.trainset = trainset
self.testset = testset
# Train the model
def fit(self, parameters, config):
set_weights(self.net, parameters)
train_model(self.net, self.trainset)
return get_weights(self.net), len(self.trainset), {}
# Test the model
def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
set_weights(self.net, parameters)
loss, accuracy = evaluate_model(self.net, self.testset)
return loss, len(self.testset), {"accuracy": accuracy}
# 定义一个函数以便于在需要时生成 FlowerClient 对象
def client_fn(context: Context) -> Client:
net = SimpleModel()
partition_id = int(context.node_config["partition-id"])
client_train = train_sets[int(partition_id)]
client_test = testset
return FlowerClient(net, client_train, client_test).to_client()
# 生成一个 FlowerClient 对象
client = ClientApp(client_fn)
# 定义一个函数用于评估模型的准确性
def evaluate(server_round, parameters, config):
net = SimpleModel()
set_weights(net, parameters)
_, accuracy = evaluate_model(net, testset)
_, accuracy137 = evaluate_model(net, testset_137)
_, accuracy258 = evaluate_model(net, testset_258)
_, accuracy469 = evaluate_model(net, testset_469)
log(INFO, "test accuracy on all digits: %.4f", accuracy)
log(INFO, "test accuracy on [1,3,7]: %.4f", accuracy137)
log(INFO, "test accuracy on [2,5,8]: %.4f", accuracy258)
log(INFO, "test accuracy on [4,6,9]: %.4f", accuracy469)
if server_round == 3:
cm = compute_confusion_matrix(net, testset)
plot_confusion_matrix(cm, "Final Global Model")
# 定义一个神经网络获取初始参数赋值给 params
net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))
# 定义一个函数以便于在需要时生成 ServerAppComponents 对象
def server_fn(context: Context):
# 聚合策略为 FedAvg
strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=0.0,
initial_parameters=params,
evaluate_fn=evaluate,
)
config=ServerConfig(num_rounds=3)
return ServerAppComponents(
strategy=strategy,
config=config,
)
# 创建一个 ServerApp 实例
server = ServerApp(server_fn=server_fn)
开始训练
1
2
3
4
5
6
7
8
9
10
# Initiate the simulation passing the server and client apps
# Specify the number of super nodes that will be selected on every round
# Flower 称客户端为 super nodes 以强调这些节点在联邦学习中的重要性。
run_simulation(
server_app=server,
client_app=client,
# 客户端节点的个数
num_supernodes=3,
backend_config=backend_setup,
)
训练得到的模型测试结果如图。