Center Loss
是一种用于增强深度神经网络特征分布的损失函数,广泛应用于人脸识别等分类任务。它的目的是在优化过程中最小化每个类别的特征中心和该类别的样本特征之间的距离,从而增强类内的紧凑性(intra-class compactness),在特征空间中更好地分离不同类别。
1. 背景和问题
在深度学习中的图像分类任务中,模型通常使用交叉熵(Cross Entropy)或 Softmax Loss 来优化类别分布。这类损失主要关注不同类别的可分性,即让网络学习到在输出层上区分不同类别。
然而,仅靠交叉熵无法确保同一类别的样本在特征空间中相对集中,即类内紧凑性不足。例如,在人脸识别或细粒度分类中,同一个人的照片应该在特征空间中较为集中,而不同人的照片应该较远。因此,中心损失(Center Loss)被提出,用于在特征空间中拉近同类样本的分布,从而提升类内的紧凑性。
2. Center Loss 的思想和公式
Center Loss 的核心思想是让每个类别的样本特征靠近一个中心点。我们给每一个类别设定一个特征中心 ck,让属于该类别的所有样本尽量靠近 ck。从而得到的损失函数定义如下:
其中:
- N 是批次的样本数。
- fi 表示第 i 个样本通过网络得到的特征向量。
- yi 表示第 i 个样本的真实类别标签。
- 表示样本 i 所属类别 yi 的中心。
这个损失函数的含义是:让每个样本的特征尽量靠近其所属类别的中心,以实现类内紧凑性。
3. 为什么需要 Center Loss
Softmax Loss 只能通过让不同类别在输出空间中分离来实现分类效果,但不能控制同类别样本之间的距离。Center Loss 弥补了这一不足,让同一类别的特征聚集在一个点附近。这种特征聚集可以帮助模型在验证或测试阶段,对相同类别的样本有更好的辨识和区分能力。
4. Center Loss 与 Softmax Loss 的结合
为了同时实现类内紧凑性和类间分离性,Center Loss 和 Softmax Loss 一般会结合使用。联合损失函数通常定义为:
其中:
- 是交叉熵损失。
- λ 是一个超参数,控制 Center Loss 在总损失中的权重。
通过调整 λ 的大小,我们可以平衡类内紧凑性和类间分离性。通常来说,λ 会取一个较小的值,比如 0.1~0.5,来避免 Center Loss 过度影响模型学习。
5. 特征中心的更新策略
每个类别的特征中心并不是直接优化得到的,而是随着训练过程逐步更新。更新公式为:
其中:
- α 是学习率,通常设置为一个较小的值(例如 0.5)。
- 是样本特征和其类别中心的差距。
通过逐步更新,每次仅会根据当前样本的特征向量 fi 对其类别中心 做出细微调整,从而保证中心更新的稳定性,避免因批次样本的波动导致的过度更新。
6. Center Loss 的实现步骤
我们可以使用 PyTorch 来实现 Center Loss,以下是详细的实现步骤。
步骤 1:定义 Center Loss 类
import torch
import torch.nn as nnclass CenterLoss(nn.Module):def __init__(self, num_classes, feat_dim, lambda_c=0.5):super(CenterLoss, self).__init__()self.num_classes = num_classes # 类别数self.feat_dim = feat_dim # 特征维度self.lambda_c = lambda_c # 平衡系数# 初始化每个类别的特征中心self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))def forward(self, x, labels):"""x: 当前批次样本的特征向量 (batch_size, feat_dim)labels: 当前批次样本的类别标签 (batch_size)"""batch_size = x.size(0)# 取出当前批次样本对应类别的中心centers_batch = self.centers.index_select(0, labels)# 计算 Center Losscenter_loss = self.lambda_c * 0.5 * torch.sum((x - centers_batch) ** 2) / batch_sizereturn center_loss
在 CenterLoss
类中:
- 初始化每个类别的特征中心
self.centers
。nn.Parameter
表示这些中心会参与模型参数的更新。 forward
方法中,通过index_select
提取当前批次中每个样本所属类别的中心。- 计算当前批次样本的 Center Loss。
步骤 2:在训练过程中联合使用 Center Loss 和 Softmax Loss
下面的代码展示如何将 Center Loss 和 Softmax Loss 联合使用,并在训练过程中更新模型
# 定义损失函数
criterion_softmax = nn.CrossEntropyLoss()
center_loss = CenterLoss(num_classes=10, feat_dim=128, lambda_c=0.5)# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
center_optimizer = torch.optim.SGD(center_loss.parameters(), lr=0.01) # 更新特征中心for inputs, labels in dataloader:# 前向传播features, outputs = model(inputs) # features 是特征, outputs 是分类输出# 计算 Softmax Lossloss_softmax = criterion_softmax(outputs, labels)# 计算 Center Lossloss_center = center_loss(features, labels)# 总损失loss = loss_softmax + loss_center# 反向传播optimizer.zero_grad()center_optimizer.zero_grad() # 清除中心的梯度loss.backward()optimizer.step()center_optimizer.step() # 更新中心
在这段代码中:
features
是模型中间层的特征输出,outputs
是最后的分类层输出。loss_softmax
和loss_center
分别计算 Softmax Loss 和 Center Loss。center_optimizer
专门用于更新CenterLoss
中的中心参数,以确保特征中心随着训练不断调整。
7. 总结
- 平衡系数 λ:Center Loss 是一个辅助损失,λ 的设置通常为较小的值(如 0.1~0.5),以确保 Center Loss 不会主导优化过程。
- 类内紧凑性和类间分离性:Center Loss 通过缩小同类样本的特征距离,确保类内紧凑性;而 Softmax Loss 则确保不同类间的分离性。
- 特征中心的更新速率:更新速率 α 应取较小值,以避免类别中心在每个批次间变化太大,导致不稳定。
通过结合 Center Loss 和 Softmax Loss,模型能够在特征空间中实现更清晰的类别划分和特征分布结构。这种方法在精细识别任务(如人脸识别)中有显著的效果。