您的位置:首页 > 游戏 > 手游 > pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

2024/12/23 7:14:22 来源:https://blog.csdn.net/yueguang8/article/details/141260814  浏览:    关键词:pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

在 PyTorch 中,.pt 模型文件通常包含以下几类数据:

        模型参数:

                存储模型的权重和偏置参数。

        优化器状态:

                包含优化器的状态信息,以便在恢复训练时能够从中断的地方继续。

        训练状态:

                一些训练过程中的信息,例如当前的 epoch 数和训练进度。

        其他元数据:

                包括模型的配置、训练时使用的超参数等。

        在讲解pytorch pt(pth)文件中保存了什么内容之前,需要先了解pt在保存时保存了那些参数。

以YOLO系列pt保存代码来介绍说明:

1. 模型保存代码:

 def save_model(self):ckpt = {'epoch': self.epoch, #'best_fitness': self.best_fitness,'model': deepcopy(de_parallel(self.model)).half(),'ema': deepcopy(self.ema.ema).half(),'updates': self.ema.updates,'optimizer': self.optimizer.state_dict(),'train_args': vars(self.args),  # save as dict'date': datetime.now().isoformat(),'version': __version__}# Use dill (if exists) to serialize the lambda functions where pickle does not do thistry:import dill as pickleexcept ImportError:import pickle# Save last, best and deletetorch.save(ckpt, self.last, pickle_module=pickle)if self.best_fitness == self.fitness:torch.save(ckpt, self.best, pickle_module=pickle)if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)del ckpt

参数说明:

        'epoch': 当前的训练轮次数。

        'best_fitness': 最佳性能指标的数值。

        'model': 深拷贝(deepcopy)并将模型参数进行半精度(half)转换后的模型。

        'ema': 深拷贝并将指数移动平均模型参数进行半精度转换后的指数移动平均模型。

        'updates': 指数移动平均模型的更新次数。

        'optimizer': 优化器的状态字典(state_dict)。

        'train_args': 训练参数的字典表示,使用vars(self.args)将self.args对象转换为字典。

        'date': 当前的日期和时间,使用datetime.now().isoformat()获取。

        'version': 代码的版本号,通过__version__获取。

        其中:model中保存的模型的结构,train_args中保存训练时的一些参数(超参数)。

通过上述功能函数可以看到pytorch保存的pt文件中的内容。

补充说明:

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

        torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

2. 模型加载介绍

下面通过Debug来详解pt中的具体内容:

首先加载模型,代码如下:

import sys
import argparse
import os
import struct
import torch
pt_file = "./yolov8n.pt"
wts_file = "./yolov8n.wts"
# Initialize
device = 'cpu'
# Load model
modelAll = torch.load(pt_file, map_location=device)
model = modelAll['model'].float()  # load to FP32
#model = torch.load(pt_file, map_location=device)['model'].float()  # load to FP32anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchors')
model.to(device).eval()
with open(wts_file, 'w') as f:f.write('{}\n'.format(len(model.state_dict().keys())))for k, v in model.state_dict().items():print("key={0}, v={1}".format(k,v))vr = v.reshape(-1).cpu().numpy()f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex())f.write('\n')

 Debug结果如下所示,分别对应save_model()中保存的内容

其中model(model = modelAll['model'].float())中内容如下:

       model的类型为DetectionModel,里面包含了模型结构(model.model)以及参数信息(model.args)及构造网络时的配置参数信息(model.yaml)以及目标类别及个数、stride等信息。 

3. 模型权重解析保存

        model.state_dict()是一个字典,键是参数的名称,值是对应的 tensor。

        其中保存着模型的权重(Weights)和偏置值(Biases)以及运行均值和方差(例如,Batch Normalization 层的 running_mean 和 running_var,用于推理时)等信息。

        权重解析保存代码如下:

with open(wts_file, 'w') as f:f.write('{}\n'.format(len(model.state_dict().keys())))for k, v in model.state_dict().items():print("key={0}, v={1}".format(k,v))vr = v.reshape(-1).cpu().numpy()f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex())f.write('\n')

代码功能介绍:

  1. 使用写模式打开一个文件 wts_file,以便保存模型的参数。
  2. 将模型参数的数量写入文件。
  3. 循环遍历每个参数的键名 k 和对应的值 v。
  4. 将参数 v 重塑为一维数组,并将其从 GPU 移动到 CPU(如果适用),然后转换为 NumPy 数组。
  5. 写入参数的名称和长度。
    for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex())

        遍历每个参数值,使用大端格式(‘>’)将其转换为浮点数并写入文件.

pt解包后保存后的文件内容如下:

上述代码可以将pt格式模型,转化为Nvidia TensorRT部署需要的文件。 

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com