您的位置:首页 > 文旅 > 旅游 > DETR代码解读

DETR代码解读

2024/12/22 13:43:30 来源:https://blog.csdn.net/lishanlu136/article/details/140517942  浏览:    关键词:DETR代码解读

DETR检测算法论文解读在这里


文章目录

  • 1、定义模型
    • 1.1 定义backbone
    • 1.2 定义Transformer
    • 1.3 定义DETR结构
  • 2、定义数据集
  • 3、训练


首先clone代码到本地
git clone https://github.com/facebookresearch/detr.git
代码结构如下图所示:
在这里插入图片描述
代码整体结构比较清晰,工程不是很复杂。下面从训练文件main.py为入口,一步一步解读DETR的工程。

通过查看main.py中的main函数,可以将DETR的训练过程分为:

定义模型 -> 定义优化器、学习率下降策略 -> 定义数据集,生成Dataloader -> 进入训练过程并更新参数 -> 保存模型 -> 评估模型性能

在这里插入图片描述


1、定义模型

下面看看build_model函数的实现,它定义在models/init.py文件中,调用的是models/detr.py文件中的build函数,来看看是怎么实现的:
在这里插入图片描述
可以看到,build函数返回定义的DETR模型,而且还返回集合损失评估类对象以及为了后面评估模型性能方便,将预测结果转换成coco api的形式。

1.1 定义backbone

build_backbone函数定义在models/backbone.py文件下,该函数包含定义位置embedding和backbone结构,如下:
在这里插入图片描述
backbone结构定义就在models/backbone.py文件中,直接调用torchvision自带的resnet模型定义,如下:
在这里插入图片描述
位置embedding的函数定义在models/position_encoding.py中,有两种方式可选,一种是固定的sine方式,一种是自动学习方式,如下:
在这里插入图片描述
两种方式的实现也在models/position_encoding.py文件中。

1.2 定义Transformer

Transformer定义在models/transformer.py文件中,由Transformer类实现,如下图,主要包含encoder和decoder,分别由TransformerEncoder类和TransformerDecoder类定义,这两个类都是由encoder_layer(block)和decoder_layer(block)的堆叠而成,所以直接看TransformerEncoderLayer类和TransformerDecoderLayer类即可,如下:
在这里插入图片描述
在这里插入图片描述
可以看到该类定义了两种前向传播方式,一个是layer_norm在操作之前做,一个是layer_norm在操作之后做。总体步骤都是首先融合特征图(src)与位置编码,然后做多头注意力,最后经过两个全连接层操作。
同理对于TransformerDecoderLayer类也一样
在这里插入图片描述

1.3 定义DETR结构

对于DETR类,就是直接利用backbone和transformer构建起来的,如下:
在这里插入图片描述
红框处理的DETR的核心,图像经过backbone得到特征图和位置编码,特征图经过卷积降维和位置编码一起输入transformer中,最后经class_embed和bbox_embed做分类和bounding box预测。


2、定义数据集

从datasets/_init_.py文件的build_dataset函数可以看出,代码定义了两种数据读取方式,即“coco”和“coco_panoptic”,以常用的“coco”数据读取方式为例,进入到datasets/coco.py文件中,主要是CoCoDetection类负责读取:
在这里插入图片描述
再进入里面一层,实际是ConvertCoCoPolysToMask类负责读取,并且通过return_masks参数切换是否用于分割训练用的标注方式。
在这里插入图片描述

数据增强
在读取了图像和对应标注后,就做了数据增强,调用了make_coco_transforms函数对数据集做了水平翻转,随机resize,随机crop,随机scale等操作
在这里插入图片描述


3、训练

在这里插入图片描述
训练的重要代码都在engine.py文件中的train_one_epoch函数中,如下:
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com