您的位置:首页 > 新闻 > 热点要闻 > 教育网站制作哪个好_站长工具源码_企业网站建设的重要性_江苏seo网络

教育网站制作哪个好_站长工具源码_企业网站建设的重要性_江苏seo网络

2025/2/25 16:39:25 来源:https://blog.csdn.net/qq_28912651/article/details/145817235  浏览:    关键词:教育网站制作哪个好_站长工具源码_企业网站建设的重要性_江苏seo网络
教育网站制作哪个好_站长工具源码_企业网站建设的重要性_江苏seo网络

用 ViT 做一个简单的图像分类任务

在 CIFAR-10 数据集上进行图像分类。通过 Hugging Face 的 transformers 库,加载一个预训练的 ViT 模型,并使用 PyTorch 进行微调。通过训练模型,评估测试集上的准确性,并可视化部分预测结果

可以将此方法应用到其他数据集或任务上,只需调整数据加载部分以及输出类别数

目录

1 创建环境并安装必要的库

2 导入依赖项

3 数据准备

4 加载 ViT 模型

5 训练模型 train.py

6 测试和评估 eval.py

7 可视化结果 plot.py


1 创建环境并安装必要的库

1. Anaconda 创建环境

conda create -n ViT python=3.8

2. 激活环境

conda activate ViT

3. 安装所需的库

pip install torch torchvision transformers matplotlib

2 导入依赖项

import torch
from torch import nn
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt

数据准备

使用 CIFAR-10 数据集作为例子,该数据集包含10个类别的彩色图像。用以下代码加载和预处理数据集

CIFAR-10 数据集由10个类别的60000张32x32彩色图像组成,每个类别有6000张图像。有50000个训练图像和10000个测试图像
数据集分为5个 training batches 和1个test batch,每个 batch 有10000张图像。test batch 包含从每个类别中随机选择的1000张图像。training batches 包含随机顺序的剩余图像,但某些 training batches 可能包含来自一个类的图像多于另一个类。在它们之间,training batches 包含来自每个类的5000张图像
以下是数据集中的类,以及每个类中的10张随机图像:

下载的是 python 版本,代码中会自动加载下载

# 定义图像预处理操作
transform = transforms.Compose([transforms.Resize((224, 224)),   # 调整图像大小为224x224,以适配ViTtransforms.ToTensor(),           # 转换图像为Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)# 检查标签的最小值和最大值
for images, labels in train_loader:print(labels.min(), labels.max())  # 确保标签值在0到9之间break

加载 ViT 模型

加载预训练的 ViT 模型有多种方法,可以参考之前的笔记文章——ViT 相关开源项目

此处使用 Hugging Face 的transformers库加载预训练的ViT模型

更具体而言,使用 ViTForImageClassification 模型,它已预训练并适合图像分类任务

# 加载预训练的ViT模型
model = ViTForImageClassification.from_pretrained('/home/yejiangchen/Desktop/Codes/ViT/config/')
# CIFAR-10有10个类别
model.classifier = nn.Linear(model.config.hidden_size, 10)  # 假设分类层的输出为10个类别
model = model.cuda()  # 如果有GPU,转移到GPU# 确保分类层已经正确初始化
print(model.classifier)  # 打印分类层以验证# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)# 设置调试模式来帮助调试CUDA错误
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"# 创建保存模型的文件夹
model_save_path = './models/'
if not os.path.exists(model_save_path):os.makedirs(model_save_path)

如果连接 Huggingface 超时,报错:

OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like google/vit-base-patch16-224-in21k is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

解决方法就是登上 huggingface,把 config.json、preprocessor_config.json、pytorch_model.bin下载到本地

例如存在 config 文件夹中:

然后在调用模型时候采用如下本地加载的方式

model = ViTForImageClassification.from_pretrained('/home/yejiangchen/Desktop/Codes/ViT/config/')

训练模型 train.py

为了训练 ViT 模型,需要定义损失函数和优化器。此处使用交叉熵损失和 Adam 优化器

# 训练模型
epochs = 3  # 设置训练的epoch数量
for epoch in range(epochs):model.train()running_loss = 0.0correct = 0total = 0for images, labels in train_loader:images, labels = images.cuda(), labels.cuda()# 前向传播outputs = model(images).logitsloss = criterion(outputs, labels)# 后向传播optimizer.zero_grad()loss.backward()optimizer.step()# 统计running_loss += loss.item()_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 打印每个epoch的损失和准确度print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')# 每个epoch后保存模型model_filename = f'{model_save_path}vit_model_epoch_{epoch+1}.pth'torch.save(model.state_dict(), model_filename)print(f'Model saved to {model_filename}')

训练结果如下:

得到模型的权重参数文件:

6 测试和评估 eval.py

在测试阶段,需要加载训练好的模型,并在测试集上评估模型的性能

使用评估模式 model.eval() 来禁用训练过程中的某些操作(如 dropout)

import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import torch.nn as nn  # 这里导入 nn 模块# 定义图像预处理操作
transform = transforms.Compose([transforms.Resize((224, 224)),   # 调整图像大小为224x224,以适配ViTtransforms.ToTensor(),           # 转换图像为Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])# 加载CIFAR-10测试集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)# 加载训练后的模型
model = ViTForImageClassification.from_pretrained('/home/yejiangchen/Desktop/Codes/ViT/config/')
model.classifier = nn.Linear(model.config.hidden_size, 10)  # CIFAR-10有10个类别
model.load_state_dict(torch.load('./models/vit_model_epoch_3.pth'))  # 加载训练好的模型
model = model.cuda()  # 使用GPU# 将模型设置为评估模式
model.eval()# 记录预测结果和标签
all_labels = []
all_preds = []with torch.no_grad():  # 在评估阶段不计算梯度for images, labels in test_loader:images, labels = images.cuda(), labels.cuda()# 前向传播outputs = model(images).logits_, predicted = torch.max(outputs, 1)# 记录标签和预测all_labels.extend(labels.cpu().numpy())all_preds.extend(predicted.cpu().numpy())# 绘制混淆矩阵
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=test_dataset.classes, yticklabels=test_dataset.classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()# 随机显示一些预测结果
import random
for _ in range(5):idx = random.randint(0, len(test_dataset) - 1)image, label = test_dataset[idx]image = image.unsqueeze(0).cuda()output = model(image).logits_, predicted = torch.max(output, 1)plt.imshow(image.squeeze().cpu().permute(1, 2, 0))plt.title(f'True: {test_dataset.classes[label]} | Predicted: {test_dataset.classes[predicted]}')plt.show()

运行结果如下:

模型已经成功完成了评估,输出了测试集上的损失(Test Loss: 0.1021)和准确率(Test Accuracy: 97.09%)。这表明模型在测试集上的表现非常好,具有较高的准确率

测试损失(Test Loss):表示模型在测试集上的损失函数值,通常损失越低表示模型越优秀

测试准确率(Test Accuracy):模型在测试集上正确分类的样本占所有样本的比例,97.09% 表示模型能够正确分类绝大部分测试集样本

可视化结果 plot.py

为了更好地理解模型的性能,将测试结果可视化。通常绘制混淆矩阵预测样本

  • 混淆矩阵:使用 sklearn.metrics.confusion_matrix 生成混淆矩阵,并通过 seaborn 的 heatmap 绘制热图。混淆矩阵显示了真实标签与预测标签之间的关系,帮助了解哪些类别易混淆
  • 预测样本:随机选择几张图像,并展示其真实标签与模型预测标签,以便直观评估模型性能

安装额外的库:

pip install scikit-learn
pip install seaborn

运行结果如下:

每行表示真实标签,每列表示模型的预测结果,矩阵中的数字显示了模型预测的数量

混淆矩阵分析:

  • 对角线上的数值(如 airplane 类的986)表示模型正确预测的数量,数字越大,模型对该类别的预测越准确
  • 非对角线上的数值表示误分类的情况。例如,bird 类被错误地预测为其他类别的次数。通过混淆矩阵,可以发现哪些类别之间容易混淆,进而进行优化

最后可以看到一个简单的项目的几个文件:

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com