您的位置:首页 > 新闻 > 资讯 > com域名查询_app一键生成平台免费软件_软文发布网站_个人接app推广单去哪里接

com域名查询_app一键生成平台免费软件_软文发布网站_个人接app推广单去哪里接

2025/2/25 3:35:22 来源:https://blog.csdn.net/qq_50645064/article/details/144999527  浏览:    关键词:com域名查询_app一键生成平台免费软件_软文发布网站_个人接app推广单去哪里接
com域名查询_app一键生成平台免费软件_软文发布网站_个人接app推广单去哪里接

 概述

这篇博客主要是找到在RT-DETR中,模型和数据集是怎么传入train_ine_epoch中进行训练的

一、train.py

二、solver/__init__.py文件

 在train.py的头文件中from src.solver import TASKS,TASKS不是文件,可以看到左侧有init.py文件。

Python中的__init__.py文件作用-CSDN博客

在init.py文件中,TASKS是一个字典类型变量,使用了 Python 的 类型注解,通过 Dict[str, BaseSolver] 表明 TASKS 是一个字典类型。

  • Dict 表示这是一个字典,其中:
    • 键(key)是 字符串类型str)。这里表示某个任务名称,如"detection"
    • 值(value)是 BaseSolver 类型。(BaseSolever是一个父类,DetSolver是其子类)

用于将任务名称(如 'detection')映射到对应的求解器类。

在train.py文件中,

solver = TASKS[cfg.yaml_cfg['task']](cfg)

cfg.yaml_cfg['task'] 用于获取任务类型,cfg.yaml_cfg 一般内容如下:

{"task": "detection","learning_rate": 0.001,"batch_size": 32
}

 (可看RT-DETR代码详解(官方pytorch版)——参数配置(1)-CSDN博客)

  • 这里就相当于cfg.yaml_cfg['task'] 返回 'detection'

  • TASKS[cfg.yaml_cfg['task']] 等价于 TASKS['detection']

  • 结果是 DetectionSolver(一个类)

  • DetectionSolver(cfg) 会调用类的构造方法(__init__ 方法),并传入参数 cfg

 三、solver/solver.py/BaseSolver

3.1 __init__初始化

def __init__(self, cfg: BaseConfig) -> None:self.cfg = cfg
  • 作用:初始化 BaseSolver 实例,接收一个配置对象 cfg,该配置对象通常包含训练所需的所有配置信息(如设备、优化器、数据加载器等)。

3.2 setup方法

def setup(self, ):'''Avoid instantiating unnecessary classes '''# 配置设备和属性cfg = self.cfgdevice = cfg.deviceself.device = deviceself.last_epoch = cfg.last_epoch# 初始化模型、损失函数、后处理器self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)self.criterion = cfg.criterion.to(device)self.postprocessor = cfg.postprocessor# 加载调优状态(如果有)if self.cfg.tuning:print(f'Tuning checkpoint from {self.cfg.tuning}')self.load_tuning_state(self.cfg.tuning)# 初始化混合精度、EMA、输出目录self.scaler = cfg.scalerself.ema = cfg.ema.to(device) if cfg.ema is not None else None self.output_dir = Path(cfg.output_dir)self.output_dir.mkdir(parents=True, exist_ok=True)

 

作用

  • 配置模型、损失函数、设备、后处理器等。
  • 支持 混合精度训练scaler)和 EMA机制(Exponential Moving Average)
  • 支持 Fine-Tuning(如果有预训练模型)。
  • 创建输出目录,用于保存模型的状态。

设计思想

  • 避免在每个方法中重复配置,使用 setup 方法集中进行初始化。
  • 支持分布式训练(dist.warp_model 和 dist.warp_loader)。

3.3 train方法

def train(self, ):self.setup()self.optimizer = self.cfg.optimizerself.lr_scheduler = self.cfg.lr_scheduler# 加载断点if self.cfg.resume:print(f'Resume checkpoint from {self.cfg.resume}')self.resume(self.cfg.resume)# 数据加载器self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, shuffle=self.cfg.train_dataloader.shuffle)self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle)

作用

  • 调用 setup() 方法完成初始化。
  • 配置优化器和学习率调度器。
  • 支持从断点恢复训练。
  • 配置训练和验证数据加载器,并支持分布式。

设计思想

  • 提供训练流程的模板,确保训练前的各种组件正确配置。

 3.4 eval方法

def eval(self, ):self.setup()self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle)if self.cfg.resume:print(f'resume from {self.cfg.resume}')self.resume(self.cfg.resume)
  • 作用
    • 配置验证数据加载器。
    • 支持从断点恢复,以便进行测试或验证。

 3.5 模型保存与加载

3.5.1 state_dict()

def state_dict(self, last_epoch):'''state dict'''state = {}state['model'] = dist.de_parallel(self.model).state_dict()state['date'] = datetime.now().isoformat()state['last_epoch'] = last_epochif self.optimizer is not None:state['optimizer'] = self.optimizer.state_dict()if self.lr_scheduler is not None:state['lr_scheduler'] = self.lr_scheduler.state_dict()if self.ema is not None:state['ema'] = self.ema.state_dict()if self.scaler is not None:state['scaler'] = self.scaler.state_dict()return state
  • 作用
    • 将模型、优化器、学习率调度器、EMA 和混合精度等状态保存为字典,以便后续加载。

3.5.2 load_state_dict()

def load_state_dict(self, state):'''load state dict'''# 加载模型、优化器、调度器等状态if getattr(self, 'last_epoch', None) and 'last_epoch' in state:self.last_epoch = state['last_epoch']print('Loading last_epoch')if getattr(self, 'model', None) and 'model' in state:if dist.is_parallel(self.model):self.model.module.load_state_dict(state['model'])else:self.model.load_state_dict(state['model'])print('Loading model.state_dict')if getattr(self, 'ema', None) and 'ema' in state:self.ema.load_state_dict(state['ema'])print('Loading ema.state_dict')if getattr(self, 'optimizer', None) and 'optimizer' in state:self.optimizer.load_state_dict(state['optimizer'])print('Loading optimizer.state_dict')if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:self.lr_scheduler.load_state_dict(state['lr_scheduler'])print('Loading lr_scheduler.state_dict')if getattr(self, 'scaler', None) and 'scaler' in state:self.scaler.load_state_dict(state['scaler'])print('Loading scaler.state_dict')
  • 作用
    • 从保存的状态字典中恢复模型和训练过程。
    • 支持分布式训练环境。

 3.5.3 save 和 resume

def save(self, path):'''save state'''state = self.state_dict()dist.save_on_master(state, path)def resume(self, path):'''load resume'''state = torch.load(path, map_location='cpu')self.load_state_dict(state)
  • 作用
    • 保存和加载模型的断点状态,支持分布式的保存。

3.6 load_tuning_state方法

def load_tuning_state(self, path):"""only load model for tuning and skip missed/dismatched keys"""if 'http' in path:state = torch.hub.load_state_dict_from_url(path, map_location='cpu')else:state = torch.load(path, map_location='cpu')module = dist.de_parallel(self.model)if 'ema' in state:stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])else:stat, infos = self._matched_state(module.state_dict(), state['model'])module.load_state_dict(stat, strict=False)print(f'Load model.state_dict, {infos}')
  • 作用
    • 加载调优状态时,跳过不匹配或缺失的权重。
    • 用于 Fine-Tuning 场景。

 四、solver/det_solver.py/DetSolver

TASKS中detection任务对应的DetSolver类,它继承自之前定义的 BaseSolver 类,并实现了具体的训练(fit 方法)和验证(val 方法)逻辑

4.1 fit方法

fit 方法定义了目标检测模型的训练流程,包括训练、验证、保存模型状态等。

def fit(self):print("Start training")self.train()  # 初始化训练配置...# 开始训练循环for epoch in range(self.last_epoch + 1, args.epoches):...# 训练单个 epochtrain_stats = train_one_epoch(self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)# 更新学习率调度器self.lr_scheduler.step()# 保存模型状态checkpoint_paths = [self.output_dir / 'checkpoint.pth']...for checkpoint_path in checkpoint_paths:dist.save_on_master(self.state_dict(epoch), checkpoint_path)# 验证模型module = self.ema.module if self.ema else self.modeltest_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir)...# 打印和保存日志...

这部分内容可以对应DETR系列中main.py文件中的部分:

 找到这部分内容就好了,知道模型和数据集是怎么传入代码中进行训练的,后面就可以根据传入的模型和数据找到对应的初始位置然后进行修改

详细解析
(1) 准备工作
self.train()
args = self.cfg
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
best_stat = {'epoch': -1, }
  • 调用 self.train() 进行初始化,包括配置优化器、学习率调度器等。
  • 计算模型的参数总数(n_parameters),便于日志记录。
  • 获取验证数据集的 COCO 接口对象(base_ds),用于后续评估。
(2) 训练循环
for epoch in range(self.last_epoch + 1, args.epoches):if dist.is_dist_available_and_initialized():self.train_dataloader.sampler.set_epoch(epoch)train_stats = train_one_epoch(self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)
  • 分布式支持:如果使用分布式训练(dist),为每个 epoch 设置随机采样器的种子。
  • 单次 epoch 训练
    • 调用 train_one_epoch 方法,完成模型在一个 epoch 内的训练(包括前向传播、计算损失、反向传播、权重更新等)。
    • 支持 梯度裁剪clip_max_norm)、EMA 模型更新 和 混合精度训练
(3) 更新学习率
self.lr_scheduler.step()
  • 调用学习率调度器,调整优化器的学习率。
(4) 保存模型状态
if self.output_dir:checkpoint_paths = [self.output_dir / 'checkpoint.pth']if (epoch + 1) % args.checkpoint_step == 0:checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')for checkpoint_path in checkpoint_paths:dist.save_on_master(self.state_dict(epoch), checkpoint_path)
  • 定期保存模型的断点状态(包括模型权重、优化器状态等)。
  • 默认保存为 checkpoint.pth,并在指定训练轮数(如每 100 轮)的基础上保存额外的检查点。
(5) 验证模型
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
)

  • 使用 evaluate 方法对模型进行验证,测量其在验证集上的性能。
  • 如果启用了 EMA 模型,则使用 EMA 模型进行验证。
(6) 更新最佳结果(best_stat
for k in test_stats.keys():if k in best_stat:best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']best_stat[k] = max(best_stat[k], test_stats[k][0])else:best_stat['epoch'] = epochbest_stat[k] = test_stats[k][0]
print('best_stat: ', best_stat)

  • 将当前 epoch 的验证结果与 best_stat 进行对比,更新最佳性能记录。
(7) 记录日志
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},**{f'test_{k}': v for k, v in test_stats.items()},'epoch': epoch,'n_parameters': n_parameters}if self.output_dir and dist.is_main_process():with (self.output_dir / "log.txt").open("a") as f:f.write(json.dumps(log_stats) + "\n")

  • 记录训练和验证的统计信息,并将其保存为 JSON 格式的日志文件。
(8) 保存评估结果
if coco_evaluator is not None:(self.output_dir / 'eval').mkdir(exist_ok=True)if "bbox" in coco_evaluator.coco_eval:filenames = ['latest.pth']if epoch % 50 == 0:filenames.append(f'{epoch:03}.pth')for name in filenames:torch.save(coco_evaluator.coco_eval["bbox"].eval,self.output_dir / "eval" / name)

  • 保存 COCO 评估结果(如 bbox 的评估指标)。

 4.2 val方法

val 方法定义了模型的验证流程,主要用于评估模型在验证集上的性能。

def val(self):self.eval()  # 初始化验证配置base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)module = self.ema.module if self.ema else self.modeltest_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,self.val_dataloader, base_ds, self.device, self.output_dir)if self.output_dir:dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")

五、solver/det_engine.py/train_one_epoch

det_engine.py就和detr系列中的engine.py文件内容一样了

 

 

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com