您的位置:首页 > 健康 > 养生 > PyTorch(一)模型训练过程

PyTorch(一)模型训练过程

2024/12/23 16:41:28 来源:https://blog.csdn.net/qq_45031509/article/details/139856171  浏览:    关键词:PyTorch(一)模型训练过程

PyTorch(一)模型训练过程

#c 总结 实践总结

该实践从「数据处理」开始到最后利用训练好的「模型」预测,感受到了整个模型的训练过程。其中也有部分知识点,例如定义神经网络,只是初步的模仿,有一个比较浅的认识,还需要继续学习原理。

整个流程:
「准备数据」,「创建数据加载器」,「选择训练设备」,「定义神经网络」,「定义损失函数和优化器」,「定义训练和测试函数」,「迭代训练」,「保存模型」,「加载模型」,「模型预测」。

相关知识点:
1.Dataset与DataLoader
2.迭代器
3.模型定义
4.损失函数与优化器
5.模型训练与测试
6.模型保存,加载与预测

1 数据处理

#d Dataset与DataLoader

在处理数据时,PyTorch有两个基本的原语来与数据交互:torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 用于存储样本以及它们相应的标签,而 DataLoader 围绕 Dataset 封装了一个「迭代器」。
Dataset 类通常用来定义数据集,它包含了数据和标签。而 DataLoader 类则是用来批量加载数据集,支持自动加载、打乱数据、多线程加载等功能,使得数据的加载更加高效和灵活。

#e 导入库 Dataset与DataLoader

import torch
from torch import nn# 神经网络模块
from torch.utils.data import DataLoader# 数据加载器
from torchvision import datasets# 数据集
from torchvision.transforms import ToTensor# 图像转换为张量

#c 补充 特定领域库 Dataset与DataLoader

PyTorch 提供了特定领域的库,比如 TorchTextTorchVisionTorchAudio,它们都包含了Datasettorchvision.datasets 模块包含了许多现实世界视觉数据集的 Dataset 对象,例如 CIFAR、COCO。每个 TorchVisionDataset都包括两个参数:transformtarget_transform,它们分别用来修改样本和标签。

#d 迭代器(Iterable)

迭代器(Iterable)是一种允许程序员遍历一个容器(特别是列表等序列类型)的对象。在Python中,迭代器遵循迭代协议,即它们实现了__iter__()方法,该方法返回一个迭代器对象本身,这个对象还需要实现__next__()方法,该方法在每次迭代时返回容器中的下一个项目。通过提供一种统一、高效、按需处理数据的方式,极大地简化了数据遍历和处理的复杂性。

「迭代器」解决的问题:

  1. 统一的遍历接口:迭代器提供了一种统一的方法来遍历各种类型的数据容器(如列表、元组、字典等),而不需要知道容器的内部结构。
  2. 内存效率:迭代器允许按需遍历元素,而不是一次性将所有元素加载到内存中。这对于遍历大数据集特别有用,因为它可以显著减少程序的内存使用。
  3. 惰性计算:迭代器支持惰性计算,这意味着数据元素是在需要时才被计算和返回,而不是在迭代器创建时。这可以提高计算效率,特别是在处理复杂或无限的数据序列时。

没有「迭代器」的影响:

  1. 遍历复杂性增加:没有迭代器,程序员需要为不同类型的数据结构编写不同的遍历代码,这不仅增加了开发的复杂性,也降低了代码的可重用性。
  2. 内存效率降低:在处理大型数据集时,可能需要一次性将所有数据加载到内存中,这会导致显著的内存消耗,甚至可能导致内存不足的错误。
  3. 减少惰性计算的机会:没有迭代器机制,很难实现按需计算数据元素的逻辑,这可能导致不必要的计算开销,特别是在只需要数据集一小部分或者在数据集很大时。

#e 吃自助餐 迭代器

想象一下你在一家餐厅吃自助餐。自助餐提供了一个装满不同菜肴的长桌子,你拿着一个盘子,从一端开始,挨个检查每种菜肴,决定是否将其加入你的盘子。在这个过程中,你(顾客)就像一个迭代器,而长桌子上的菜肴就像是一个可迭代的容器。你一次检查一个菜肴,直到遍历完所有的菜肴,或者你的盘子满了为止。

#e 迭代访问列表 迭代器

假设我们有一个列表(List)numbers = [1, 2, 3, 4, 5],使用iter(numbers)创建了一个迭代器,它能够遍历列表numbers中的每个元素。使用next(iterator)可以获取容器中的下一个元素。当所有元素都被遍历完毕时,next()会抛出一个StopIteration异常,表示没有更多元素可以访问,这时我们结束循环。

numbers = [1, 2, 3, 4, 5]  # 可迭代的容器
iterator = iter(numbers)  # 创建迭代器while True:try:# 使用next()获取下一个元素number = next(iterator)print(number)except StopIteration:# 如果所有元素都遍历完毕,则结束循环break

#c 关联 相关概念

「迭代器」影响的「概念」:

  1. 可迭代对象(Iterable):任何实现了__iter__()方法的对象都是可迭代的,该方法需要返回一个迭代器对象。迭代器本身也是可迭代的,因为它实现了__iter__()方法,并返回自身。

  2. 生成器(Generator):生成器是一种特殊类型的迭代器,它使用函数加上yield语句来实现,无需手动实现__iter__()__next__()方法。生成器简化了迭代器的创建过程,直接受到了迭代器概念的启发。

  3. 循环(Loops):例如for循环和while循环,在Python中,for循环内部实际上使用迭代器来遍历可迭代对象。

  4. 函数式编程工具:如map()filter()reduce()等函数,它们接受一个函数和一个可迭代对象作为输入,内部通过迭代器遍历可迭代对象。

影响「迭代器」的概念:

  1. 面向对象编程(OOP):迭代器模式是面向对象设计模式的一部分,要求对象实现特定的接口(如Python中的__iter__()__next__()方法)。面向对象的概念提供了迭代器实现的框架。

  2. 惰性计算(Lazy Evaluation):惰性计算是指仅在真正需要计算结果时才进行计算。迭代器天然支持惰性计算,因为它们一次只处理集合中的一个元素。

  3. 函数式编程(Functional Programming):函数式编程强调使用函数来处理数据。迭代器与函数式编程紧密相关,因为迭代器提供了一种遍历和处理数据集合的方法,而不改变数据本身,这与函数式编程的不可变性原则相吻合。

#c 说明 数据集的选择

本次实践使用 FashionMNIST 数据集。该数据集是一个用于衣物识别的数据集,由Zalando(一家欧洲的在线时尚零售商)提供。它被设计为原始MNIST数据集的直接替代品,用于在机器学习和计算机视觉领域的基准测试中。FashionMNIST包含了10个类别的衣物图片,每个类别有7000张图片,整个数据集分为60000张训练图片和10000张测试图片。每张图片都是28x28像素的灰度图。这些类别包括:

  1. T-shirt/top(T恤/上衣)
  2. Trouser(裤子)
  3. Pullover(套衫)
  4. Dress(连衣裙)
  5. Coat(外套)
  6. Sandal(凉鞋)
  7. Shirt(衬衫)
  8. Sneaker(运动鞋)
  9. Bag(包)
  10. Ankle boot(短靴)

#e 下载数据集 数据集的选择

如果是自行搜集数据,比如利用爬虫获取自己想要的数据,获取的数据需要进行「数据处理」,例如「删除不符合数据」,「统一数据格式」,「去重」等方式。这里下载的数据已经是符合训练的数据格式,所以不需要进行对应的数据处理的环节。

# 下载训练数据集
train_data = datasets.FashionMNIST(root="data",  # 数据存储的路径train=True,   # 指定下载的是训练数据集download=True,  # 如果数据不存在,则通过网络下载transform=ToTensor()  # 将图片转换为Tensor
)# 下载测试数据集
test_data = datasets.FashionMNIST(root="data",  # 数据存储的路径train=False,  # 指定下载的是测试数据集download=True,  # 如果数据不存在,则通过网络下载transform=ToTensor()  # 将图片转换为Tensor
)

#d 数据加载

Dataset 作为参数传递给 DataLoaderDataLoaderdataset封装一个可迭代对象,并且支持自动批处理、采样、多进程数据加载等。

#e 加载代码 数据加载

在这里,定义了一个批量大小为64,即 dataloader 可迭代对象中的每个元素将返回一个包含64个特征和标签的批次。

batch_size = 64# 批大小# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
#将dataset作为参数传入DataLoader,DataLoader会自动将数据分批,打乱数据,将数据加载到内存中
test_dataloader = DataLoader(test_data, batch_size=batch_size)for x,y in test_dataloader:print(f"Shape of x [N, C, H, W]: {x.shape}")#x.shape是一个4维张量,第一个维度是批大小,第二个维度是通道数,第三和第四维度是图像的高度和宽度print(f"Shape of y: {y.shape}, {y.dtype}")'''Shape of x [N, C, H, W]: torch.Size([64, 1, 28, 28])Shape of y: torch.Size([64]), torch.int64'''break

2 创建模型

#d 定义模型

在PyTorch中定义神经网络,需创建一个继承自nn.Module的类,并在__init__函数中定义神经网络的层,在forward函数中定义数据在神经网络中的传播路径。为了加速神经网络的训练,可以使用GPU或者MPS来训练模型。

#e 定义代码 定义模型

#使用cpu,gpu,mps的设备来训练模型
device =("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")
#Using cuda deviceclass NeuralNetwork(nn.Module):def __init__(self):#定义神经网络的层super().__init__()#调用父类的构造函数self.flatten = nn.Flatten()#将28*28的图像展平为784的向量self.linear_relu_stack = nn.Sequential(#定义一个包含三个线性层的神经网络nn.Linear(28*28,512),#输入层nn.ReLU(),#激活函数nn.Linear(512,512),#隐藏层nn.ReLU(),#激活函数nn.Linear(512,10),#输出层)def forward(self,x):#定义数据在神经网络中的传播路径x = self.flatten(x)#将图像展平logits = self.linear_relu_stack(x)#将展平后的图像传入神经网络return logits#返回输出
model = NeuralNetwork().to(device)#将模型加载到设备上
print(model)
'''
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)
'''

3 优化模型参数

#d 定义训练参数

  1. 在训练模型之前,需要定义「损失函数(loss function)」[ 和「优化器(optimizer)」。概念解释(5 相关概念)

  2. 在单个训练循环中,模型会对分批提供它的「训练数据集」进行「预测」并通过「反向传播算法」预测误差以调整模型的参数。

  3. 检查模型在测试数据集上的性能,以确保它在学习.

  4. 训练过程在多个迭代(周期)中进行。在每个周期中,模型学习参数以做出更好的预测。在每个周期打印模型的准确率和损失,希望看到准确率随着每个周期的增加而提高,损失随着每个周期的减少。

#e 损失函数与优化器

loss_fn = nn.CrossEntropyLoss()#使用交叉熵损失函数
#使用随机梯度下降优化器,model.parameters()返回模型的参数,lr=1e-3是学习率
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

#e 训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)#数据集的大小model.train()#将模型设置为训练模式for batch, (X, y) in enumerate(dataloader):#遍历数据集X, y = X.to(device), y.to(device)#将数据加载到设备上# 计算预测误差pred = model(X)#对输入的数据进行预测loss = loss_fn(pred, y)#计算损失,差异越小,模型预测的越准确# 反向传播loss.backward()#反向传播算法optimizer.step()#优化器更新模型参数optimizer.zero_grad()#梯度清零if batch % 100 == 0:#每100个批次打印一次loss, current = loss.item(), (batch+1) * len(X)#打印损失和当前的批次的数据量print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")'''{loss:>7f}:表示损失值以浮点数形式打印,总宽度为7位,右对齐。{current:>5d}:表示当前处理的总数据量以整数形式打印,总宽度为5位,右对齐。{size:>5d}:表示整个数据集的大小以整数形式打印,总宽度为5位,右对齐'''

#e 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)#数据集的大小num_batches = len(dataloader)#批次的数量model.eval()#将模型设置为评估模式test_loss, correct = 0, 0#初始化损失和正确的数量with torch.no_grad():#关闭梯度计算for X, y in dataloader:X, y = X.to(device), y.to(device)#将数据加载到设备上pred = model(X)#对输入的数据进行预测test_loss += loss_fn(pred, y).item()#计算损失correct += (pred.argmax(1) == y).type(torch.float).sum().item()#计算正确的数量'''pred.argmax(1)找出每个预测中概率最高的类别的索引,== y判断这些索引是否与真实标签相等。结果是一个布尔Tensor,通过.type(torch.float)转换为浮点数Tensor,然后使用.sum().item()计算并累加正确预测的总数。'''test_loss /= num_batches#计算平均损失correct /= size#计算正确率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

#e 迭代训练

epochs = 5#迭代次数
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)#训练模型test(test_dataloader, model, loss_fn)#测试模型
print("Done!")#训练完成
'''
运行结果
Epoch 1
-------------------------------
loss: 2.304268  [   64/60000]
loss: 2.284021  [ 6464/60000]
loss: 2.263621  [12864/60000]
loss: 2.259448  [19264/60000]
loss: 2.231920  [25664/60000]
loss: 2.221592  [32064/60000]
loss: 2.215944  [38464/60000]
loss: 2.191191  [44864/60000]
loss: 2.177027  [51264/60000]
loss: 2.141848  [57664/60000]
Test Error: Accuracy: 58.7%, Avg loss: 2.137664Epoch 2
-------------------------------
loss: 2.147467  [   64/60000]
loss: 2.139907  [ 6464/60000]
loss: 2.077062  [12864/60000]
loss: 2.094236  [19264/60000]
loss: 2.030329  [25664/60000]
loss: 1.982215  [32064/60000]
loss: 1.997371  [38464/60000]
loss: 1.923110  [44864/60000]
loss: 1.913458  [51264/60000]
loss: 1.835431  [57664/60000]
Test Error: Accuracy: 61.3%, Avg loss: 1.839774
'''

4 模型的保存

#d 保存方式

保存模型的一种常见方法是序列化内部状态字典(包含模型参数)。

#e 实现代码 保存方式

torch.save(model.state_dict(), "./model.pth")
print("Saved PyTorch Model State to ./model.pth")
'''
Saved PyTorch Model State to ./model.pth
'''

5 模型加载与预测

#d 加载流程

加载模型的过程包括重新创建模型结构,并将状态字典加载到其中。

#e 加载代码 加载流程

model = NeuralNetwork().to(device)#创建模型,to(device)将模型加载到设备上
model.load_state_dict(torch.load("./model.pth"))#加载模型

#e 预测代码

#利用模型进行预测
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
model.eval()#将模型设置为评估模式
x, y = test_data[0][0], test_data[0][1]#获取测试数据
with torch.no_grad():#关闭梯度计算pred = model(x.to(device))#对输入的数据进行预测predicted, actual = classes[pred[0].argmax(0)], classes[y]#获取预测的类别和真实的类别print(f'Predicted: "{predicted}", Actual: "{actual}"')'''Predicted: "Ankle boot", Actual: "Ankle boot"'''

#c 备注 完整python文件

AI_series_learn/PyTorch/1.快速开始/basic.py at main · togetherhkl/AI_series_learn (github.com)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com