使用ModelEmaV2优化MNIST分类模型
在深度学习模型的训练过程中,参数波动可能会导致模型在测试集上的性能不稳定。为了解决这个问题,可以使用指数移动平均(EMA)技术来平滑参数的更新,从而获得更稳定的模型。本文将介绍如何在MNIST数据集上使用ModelEmaV2来优化分类模型,并分析其效果。
实验背景
MNIST数据集是一个经典的手写数字识别数据集,包含60,000张训练图像和10,000张测试图像。我们的目标是训练一个简单的神经网络模型来分类这些手写数字,并使用EMA技术来优化模型参数。
模型定义与EMA实现
首先,我们定义一个简单的全连接神经网络模型,并实现ModelEmaV2来进行EMA参数更新。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy# 定义简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(28*28, 10)def forward(self, x):x = x.view(-1, 28*28)x = self.fc(x)return x# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义 EMA 模型
class ModelEmaV2(nn.Module):def __init__(self, model, decay=0.99, device='cpu'):super(ModelEmaV2, self).__init__()self.ema_model = copy.deepcopy(model).to(device)self.ema_model.eval()self.decay = decayself.device = devicedef update(self, model):with torch.no_grad():model_params = dict(model.named_parameters())ema_params = dict(self.ema_model.named_parameters())for k in model_params.keys():ema_params[k].mul_(self.decay).add_(model_params[k], alpha=1 - self.decay)def forward(self, x):return self.ema_model(x)
数据加载与预处理
我们使用torchvision
库来加载和预处理MNIST数据集。
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
训练与评估
我们进行4个epoch的训练,并在每个epoch结束后评估模型和EMA模型的准确率。
# 训练和评估
num_epochs = 4
results = []for epoch in range(num_epochs):model.train()for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()# 更新 EMA 模型ema_model.update(model)# 计算每个epoch的准确率model.eval()ema_model.eval()correct = 0total = 0ema_correct = 0ema_total = 0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct += (predicted == targets).sum().item()# 测试 EMA 模型ema_outputs = ema_model(inputs)_, ema_predicted = torch.max(ema_outputs.data, 1)ema_total += targets.size(0)ema_correct += (ema_predicted == targets).sum().item()normal_accuracy = 100 * correct / totalema_accuracy = 100 * ema_correct / ema_totallag = normal_accuracy - ema_accuracyresults.append({'epoch': epoch + 1,'normal_accuracy': normal_accuracy,'ema_accuracy': ema_accuracy,'lag': lag})results
实验结果分析
实验结果如下表所示:
Epoch | Normal Model Accuracy | EMA Model Accuracy | Lag |
---|---|---|---|
1 | 91.09 | 90.97 | 0.12 |
2 | 92.54 | 92.46 | 0.08 |
3 | 93.53 | 93.50 | 0.03 |
4 | 94.03 | 94.13 | -0.10 |
从结果可以看出,在训练的前几轮,EMA模型的准确率稍微滞后于正常模型,但随着训练的进行,两者的准确率逐渐接近,甚至在第四轮时,EMA模型的准确率略高于正常模型。
结论
通过实验可以看出,EMA技术在一定程度上平滑了模型参数的波动,使得模型在测试集上的表现更加稳定。尽管在训练的初期EMA模型的准确率稍有滞后,但随着训练的进行,EMA模型的表现逐渐赶上并超过了正常模型。这表明EMA技术对于提高模型的稳定性和性能具有重要作用。