您的位置:首页 > 房产 > 家装 > ema_mnist_blog

ema_mnist_blog

2025/1/15 6:51:07 来源:https://blog.csdn.net/zjh12312311/article/details/139331606  浏览:    关键词:ema_mnist_blog

使用ModelEmaV2优化MNIST分类模型

在深度学习模型的训练过程中,参数波动可能会导致模型在测试集上的性能不稳定。为了解决这个问题,可以使用指数移动平均(EMA)技术来平滑参数的更新,从而获得更稳定的模型。本文将介绍如何在MNIST数据集上使用ModelEmaV2来优化分类模型,并分析其效果。

实验背景

MNIST数据集是一个经典的手写数字识别数据集,包含60,000张训练图像和10,000张测试图像。我们的目标是训练一个简单的神经网络模型来分类这些手写数字,并使用EMA技术来优化模型参数。

模型定义与EMA实现

首先,我们定义一个简单的全连接神经网络模型,并实现ModelEmaV2来进行EMA参数更新。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy# 定义简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(28*28, 10)def forward(self, x):x = x.view(-1, 28*28)x = self.fc(x)return x# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义 EMA 模型
class ModelEmaV2(nn.Module):def __init__(self, model, decay=0.99, device='cpu'):super(ModelEmaV2, self).__init__()self.ema_model = copy.deepcopy(model).to(device)self.ema_model.eval()self.decay = decayself.device = devicedef update(self, model):with torch.no_grad():model_params = dict(model.named_parameters())ema_params = dict(self.ema_model.named_parameters())for k in model_params.keys():ema_params[k].mul_(self.decay).add_(model_params[k], alpha=1 - self.decay)def forward(self, x):return self.ema_model(x)

数据加载与预处理

我们使用torchvision库来加载和预处理MNIST数据集。

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

训练与评估

我们进行4个epoch的训练,并在每个epoch结束后评估模型和EMA模型的准确率。

# 训练和评估
num_epochs = 4
results = []for epoch in range(num_epochs):model.train()for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()# 更新 EMA 模型ema_model.update(model)# 计算每个epoch的准确率model.eval()ema_model.eval()correct = 0total = 0ema_correct = 0ema_total = 0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct += (predicted == targets).sum().item()# 测试 EMA 模型ema_outputs = ema_model(inputs)_, ema_predicted = torch.max(ema_outputs.data, 1)ema_total += targets.size(0)ema_correct += (ema_predicted == targets).sum().item()normal_accuracy = 100 * correct / totalema_accuracy = 100 * ema_correct / ema_totallag = normal_accuracy - ema_accuracyresults.append({'epoch': epoch + 1,'normal_accuracy': normal_accuracy,'ema_accuracy': ema_accuracy,'lag': lag})results

实验结果分析

实验结果如下表所示:

EpochNormal Model AccuracyEMA Model AccuracyLag
191.0990.970.12
292.5492.460.08
393.5393.500.03
494.0394.13-0.10

从结果可以看出,在训练的前几轮,EMA模型的准确率稍微滞后于正常模型,但随着训练的进行,两者的准确率逐渐接近,甚至在第四轮时,EMA模型的准确率略高于正常模型。

结论

通过实验可以看出,EMA技术在一定程度上平滑了模型参数的波动,使得模型在测试集上的表现更加稳定。尽管在训练的初期EMA模型的准确率稍有滞后,但随着训练的进行,EMA模型的表现逐渐赶上并超过了正常模型。这表明EMA技术对于提高模型的稳定性和性能具有重要作用。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com