联邦学习系列
引言
大量的标注数据会极大地提升模型训练的性能,但是公共的的可用标注数据会随着时间逐渐被耗尽,然而还有很多私人数据没有被用来训练模型,因此联邦学习有着很重要的作用。
utils1 库是一个自定义的包,包括这个包之内的所有资源都放在这里。
数据处理
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from utils1 import *
# 下载数据集
trainset = datasets.MNIST(
"./MNIST_data/", download=True, train=True, transform=transform
)
# 使用 torch 把数据集分成 3 个部分
total_length = len(trainset)
split_size = total_length // 3
torch.manual_seed(42) # 设置随机种子
part1, part2, part3 = random_split(trainset, [split_size] * 3)
# 由于联邦学习用到的数据是分布式不均匀的,所以每个部分剔除一部分数据用来模拟真实数据分布
part1 = exclude_digits(part1, excluded_digits=[1, 3, 7])
part2 = exclude_digits(part2, excluded_digits=[2, 5, 8])
part3 = exclude_digits(part3, excluded_digits=[4, 6, 9])
模型训练
SimpleModel
是一个类,是pytorch 中实现的仅有两个全连接层的神经网络。这里模型选择并不重要,只要适用于MNIST数据集就行。
1
2
3
4
5
6
7
8
model1 = SimpleModel()
train_model(model1, part1)
model2 = SimpleModel()
train_model(model2, part2)
model3 = SimpleModel()
train_model(model3, part3)
模型测试
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 选择同样的数据集,把 train 设置为 false
testset = datasets.MNIST(
"./MNIST_data/", download=True, train=False, transform=transform
)
# 表示不同的数据集
testset_137 = include_digits(testset, included_digits=[1, 3, 7])
testset_258 = include_digits(testset, included_digits=[2, 5, 8])
testset_469 = include_digits(testset, included_digits=[4, 6, 9])
# evaluate_model 接收一个模型实例和数据集为输入,返回准确度
_, accuracy1 = evaluate_model(model1, testset)
_, accuracy1_on_137 = evaluate_model(model1, testset_137)
print(
f"Model 1-> Test Accuracy on all digits: {accuracy1:.4f}, "
f"Test Accuracy on [1,3,7]: {accuracy1_on_137:.4f}"
)
_, accuracy2 = evaluate_model(model2, testset)
_, accuracy2_on_258 = evaluate_model(model2, testset_258)
print(
f"Model 2-> Test Accuracy on all digits: {accuracy2:.4f}, "
f"Test Accuracy on [2,5,8]: {accuracy2_on_258:.4f}"
)
_, accuracy3 = evaluate_model(model3, testset)
_, accuracy3_on_469 = evaluate_model(model3, testset_469)
print(
f"Model 3-> Test Accuracy on all digits: {accuracy3:.4f}, "
f"Test Accuracy on [4,6,9]: {accuracy3_on_469:.4f}"
)
# 生成 confusion matrix,矩阵表示真实的标签和预测的标签的数据量,对于缺失的标签model会不进行预测
confusion_matrix_model1_all = compute_confusion_matrix(model1, testset)
confusion_matrix_model2_all = compute_confusion_matrix(model2, testset)
confusion_matrix_model3_all = compute_confusion_matrix(model3, testset)
plot_confusion_matrix(confusion_matrix_model1_all, "model 1")
plot_confusion_matrix(confusion_matrix_model2_all, "model 2")
plot_confusion_matrix(confusion_matrix_model3_all, "model 3")
训练后的模型测试结果如上图。