您的位置:首页 > 新闻 > 资讯 > 安徽工程建设信息网站_今日新游戏开服时间表_怎么自己做个网站_全国分站seo

安徽工程建设信息网站_今日新游戏开服时间表_怎么自己做个网站_全国分站seo

2025/1/10 2:51:24 来源:https://blog.csdn.net/TianxiaZhu824/article/details/144016028  浏览:    关键词:安徽工程建设信息网站_今日新游戏开服时间表_怎么自己做个网站_全国分站seo
安徽工程建设信息网站_今日新游戏开服时间表_怎么自己做个网站_全国分站seo

教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale

知识蒸馏:让AI模型更轻更快

在人工智能快速发展的今天,我们经常需要在资源受限的设备(如手机、IoT设备)上运行AI模型。但这些设备的计算能力和内存都很有限,无法直接运行庞大的AI模型。这就带来了一个重要问题:如何将大模型的能力迁移到小设备上?知识蒸馏(Knowledge Distillation)就是解决这个问题的重要技术之一。

什么是知识蒸馏?

知识蒸馏可以形象地理解为"教师教学生"的过程。大模型(教师模型)将自己学到的"知识"传授给小模型(学生模型),帮助小模型在保持较小体积的同时,获得接近大模型的性能。

这里的"知识"主要包括:

  • 模型的输出概率分布(软标签)
  • 模型中间层的特征
  • 注意力图等信息

知识蒸馏的核心概念

1. 软标签与硬标签

  • 硬标签:传统的分类标签,比如[0,1,0]表示第二类
  • 软标签:模型输出的概率分布,比如[0.1,0.8,0.1],包含更丰富的信息

2. 温度参数

温度参数用于调节概率分布的"软硬程度":

  • 温度越高,分布越平滑
  • 温度越低,分布越接近硬标签
  • 合适的温度可以帮助学生模型更好地学习

下面是一个例子:当输入一张马的图片时,对于未调整温度(默认为1)的 Softmax 输出,正标签的概率接近 1,而负标签的概率接近 0。这种尖锐的分布对学生模型不够友好,因为它只提供了关于正确答案的信息,而忽略了错误答案的信息。即驴比汽车更像马,识别为驴的概率应该大于识别为汽车的概率。而通过温度调整后, 最后得到一个相对平滑的概率分布, 称为 “软标签” (Soft Label)。

知识蒸馏的不同方式

1. 基于输出的蒸馏

直接匹配教师模型和学生模型的输出概率分布。

2. 基于中间层特征的蒸馏

匹配模型中间层的特征,让学生模型学习教师模型的"思考过程"。

3. 基于中间层注意力图的蒸馏

传递模型的注意力机制,帮助学生模型知道"该关注什么"。

4.基于中间层权重的蒸馏

5.基于中间层稀疏模式的蒸馏

6.基于中间相关信息的蒸馏

创新的蒸馏方法

1. 自蒸馏

模型自己当老师,通过多次迭代提升性能,不需要额外的教师模型。

2. 在线蒸馏

教师模型和学生模型同时训练,相互学习,提高效率。

3.结合在线蒸馏和自蒸馏

实际应用场景

知识蒸馏在多个领域都有成功应用:

1. 目标检测

不仅传递分类知识,还包括物体定位信息。

2. 语义分割

通过像素级、成对和整体三个层面的蒸馏提升性能。

3. 生成对抗网络(GAN)

结合蒸馏、重构和对抗性损失实现模型压缩。

4. 自然语言处理

特别强调注意力机制的传递,提升文本处理能力。

网络增强:另一种思路

除了传统的知识蒸馏,网络增强(NetAug)提供了一个新视角:

  • 不是简化大模型,而是增强小模型
  • 将小模型嵌入到大模型中学习
  • 通过多重监督提升性能

代码实践

主要包含:

KD知识蒸馏        DKD解耦知识蒸馏

其区别主要集中在损失函数的不同。

现有的知识蒸馏方法主要关注于中间层的深度特征蒸馏,而对logit蒸馏的重要性认识不足。[DKD]()重新定义了传统的知识蒸馏损失函数,将其分解为目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。

- 目标类知识蒸馏(TCKD):关注于目标类的知识传递。

- 非目标类知识蒸馏(NCKD):关注于非目标类之间的知识传递。

# kd_loss
def loss(logits_student, logits_teacher, temperature):log_pred_student = F.log_softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()loss_kd *= temperature**2return loss_kd
import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):# 使用 _get_gt_mask 和 _get_other_mask 函数创建掩码,分别用于标识真实标签和其他类别。这使得损失计算可以选择性地关注特定类别。gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)# 使用 cat_mask 函数将掩码应用于学生和教师的预测,得到只关注特定类别的输出。pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)# 计算针对真实标签的 KL 散度损失(tckd_loss),并进行温度缩放tckd_loss = (F.kl_div(log_pred_student, pred_teacher, size_average=False)* (temperature**2)/ target.shape[0])# 计算针对其他类别的 KL 散度损失(nckd_loss),通过从 logits 中减去一个大的值(1000.0)来忽略真实标签的影响。pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)* (temperature**2)/ target.shape[0])# 原论文中这里加入了一个 WarmUPreturn alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):# 生成一个与 logits 形状相同的全零张量,并在真实标签对应的位置设置为 1,最终返回一个布尔掩码。这个掩码用于在损失计算中关注真实类别。target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):# 生成一个与 logits 形状相同的全一张量,并在真实标签对应的位置设置为 0,最终返回一个布尔掩码。这个掩码用于在损失计算中关注其他类别。target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):# 将输入张量 t 与两个掩码结合,计算出只关注特定类别的输出。# 由于 mask1 只保留真实类别的概率,因此这个求和操作给出了每个样本的真实类别的总概率。t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rt

完整代码:

  • KD知识蒸馏
  • DKD解耦知识蒸馏

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com