蒸馏学习流程

Posted by kalos Aner on February 22, 2025

一、选择并预训练教师模型

  • 目标 在大规模数据集上训练一个性能优异的教师模型,该模型通常拥有较深的网络结构和大量参数,因此具有很强的表达能力和较高的准确率。
  • 训练过程 使用标准的监督学习方法(如交叉熵损失)进行训练,确保教师模型能够对输入数据做出准确的预测。训练完成后,教师模型能够为每个输入样本生成预测概率分布,这些“软标签”包含了类别间的细微关系。

二、生成软标签和温度调节

软标签 教师模型在进行预测时,除了给出最终的类别预测外,还会输出每个类别的概率。与传统的硬标签(0或1)不同,这些概率分布(软标签)提供了更多信息,比如某一类别与其他类别的相似度。

温度参数 在生成软标签时,通常会引入一个温度参数 TTT 来“软化”输出概率分布。较高的温度会使得概率分布更平滑,从而让学生模型更容易捕捉到教师模型的知识细节。公式中常用的做法是将 logits 除以温度 TTT 后再计算 softmax 分布。

三、学生模型的设计与初始化

  • 模型选择 根据无人机的计算资源和响应速度要求,设计一个参数更少、计算更高效的学生模型。虽然结构上比教师模型简单,但目标是通过蒸馏过程尽可能保留教师模型中蕴含的重要特征信息。
  • 初始化策略 学生模型可以从头开始训练,也可以通过预训练或者部分迁移初始化,这有助于后续的蒸馏训练更快收敛。

四、学生模型的训练过程

  • 损失函数构成 学生模型的训练通常包含两部分损失:

    1. 传统监督损失 利用真实标签(硬标签)计算标准交叉熵损失,确保学生模型能正确分类。
    2. 蒸馏损失 使用教师模型生成的软标签,通过计算学生模型输出与教师软标签之间的差异(例如使用 KL 散度或均方误差),来传递教师模型的知识。

    综合损失可以写为:

    $\mathcal{L} = \alpha \cdot \mathcal{L}{\text{CE}}(y, \text{Student}(x)) + (1-\alpha) \cdot \mathcal{L}{\text{KD}}(\text{Teacher}(x, T), \text{Student}(x, T))$

    其中,$\alpha $控制两部分损失的权重,而 $T$ 是温度参数。

  • 反向传播与参数更新 将综合损失反向传播,更新学生模型的参数,使其在学习真实标签信息的同时,也能吸收教师模型在软标签中传递的“暗知识”。

  • 动态调整 在训练过程中,可以根据学生模型的表现动态调整温度参数和损失权重,以更好地平衡两种目标,最终达到最佳性能。

五、整体流程总结

  1. 训练教师模型:在大数据集上训练高精度模型,生成包含丰富信息的软标签。
  2. 软标签生成:通过教师模型输出软标签,并使用温度参数平滑分布。
  3. 设计学生模型:构建轻量化模型,适应实际部署场景。
  4. 联合损失训练:结合真实标签和软标签,通过综合损失函数训练学生模型。
  5. 优化与微调:根据实验结果调整超参数,确保学生模型既轻量又能保持较高性能。

通过以上步骤,学生模型不仅能够在计算资源有限的无人机上高效运行,还能在一定程度上继承教师模型的知识,提高整体性能。

六、蒸馏学习扩展

进阶蒸馏方法

  • 离线蒸馏:教师模型固定,仅用于生成软标签(最常用)。
  • 在线蒸馏:教师与学生模型同步更新,实现动态知识传递。
  • 多教师蒸馏:集成多个教师模型的输出,提升知识多样性。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# 设置超参数
BATCH_SIZE = 128
EPOCHS = 5
TEMPERATURE = 4.0  # 温度参数
ALPHA = 0.5  # 交叉熵损失和知识蒸馏损失的权重
LEARNING_RATE = 0.01

# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 定义教师模型
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义学生模型
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 计算蒸馏损失
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    # 计算教师模型和学生模型的 softmax 预测(使用温度参数)
    soft_targets = F.log_softmax(teacher_logits / temperature, dim=1)
    soft_predictions = F.log_softmax(student_logits / temperature, dim=1)

    # 计算 KL 散度损失
    kl_loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean') * (temperature ** 2)

    # 计算标准交叉熵损失
    ce_loss = F.cross_entropy(student_logits, true_labels)

    # 组合损失
    return alpha * ce_loss + (1 - alpha) * kl_loss

# 训练教师模型
def train_teacher():
    teacher = TeacherModel().to(device)
    optimizer = optim.Adam(teacher.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        teacher.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = teacher(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")
    
    torch.save(teacher.state_dict(), "teacher_model.pth")
    print("教师模型训练完成并已保存!")
    return teacher

# 训练学生模型
def train_student(teacher):
    student = StudentModel().to(device)
    teacher.eval()  # 设置教师模型为评估模式(不更新梯度)
    optimizer = optim.Adam(student.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        student.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            student_outputs = student(images)
            teacher_outputs = teacher(images).detach()  # 关闭教师模型的梯度计算
            
            loss = distillation_loss(student_outputs, teacher_outputs, labels, TEMPERATURE, ALPHA)
            loss.backward()
            optimizer.step()
        
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")

    torch.save(student.state_dict(), "student_model.pth")
    print("学生模型训练完成并已保存!")
    return student

# 评估模型
def evaluate_model(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'模型准确率: {100 * correct / total:.2f}%')

# 运行
device = torch.device("cuda"if torch.cuda.is_available() else"cpu")

# 训练教师模型
teacher_model = train_teacher()

# 训练学生模型(使用知识蒸馏)
student_model = train_student(teacher_model)

# 评估教师和学生模型
print("\n教师模型测试集准确率:")
evaluate_model(teacher_model)
print("\n学生模型测试集准确率:")
evaluate_model(student_model)