DETR检测算法论文解读在这里
文章目录
- 1、定义模型
- 1.1 定义backbone
- 1.2 定义Transformer
- 1.3 定义DETR结构
- 2、定义数据集
- 3、训练
首先clone代码到本地
git clone https://github.com/facebookresearch/detr.git
代码结构如下图所示:
代码整体结构比较清晰,工程不是很复杂。下面从训练文件main.py为入口,一步一步解读DETR的工程。
通过查看main.py中的main函数,可以将DETR的训练过程分为:
定义模型 -> 定义优化器、学习率下降策略 -> 定义数据集,生成Dataloader -> 进入训练过程并更新参数 -> 保存模型 -> 评估模型性能
1、定义模型
下面看看build_model函数的实现,它定义在models/init.py文件中,调用的是models/detr.py文件中的build函数,来看看是怎么实现的:
可以看到,build函数返回定义的DETR模型,而且还返回集合损失评估类对象以及为了后面评估模型性能方便,将预测结果转换成coco api的形式。
1.1 定义backbone
build_backbone函数定义在models/backbone.py文件下,该函数包含定义位置embedding和backbone结构,如下:
backbone结构定义就在models/backbone.py文件中,直接调用torchvision自带的resnet模型定义,如下:
位置embedding的函数定义在models/position_encoding.py中,有两种方式可选,一种是固定的sine方式,一种是自动学习方式,如下:
两种方式的实现也在models/position_encoding.py文件中。
1.2 定义Transformer
Transformer定义在models/transformer.py文件中,由Transformer类实现,如下图,主要包含encoder和decoder,分别由TransformerEncoder类和TransformerDecoder类定义,这两个类都是由encoder_layer(block)和decoder_layer(block)的堆叠而成,所以直接看TransformerEncoderLayer类和TransformerDecoderLayer类即可,如下:
可以看到该类定义了两种前向传播方式,一个是layer_norm在操作之前做,一个是layer_norm在操作之后做。总体步骤都是首先融合特征图(src)与位置编码,然后做多头注意力,最后经过两个全连接层操作。
同理对于TransformerDecoderLayer类也一样
1.3 定义DETR结构
对于DETR类,就是直接利用backbone和transformer构建起来的,如下:
红框处理的DETR的核心,图像经过backbone得到特征图和位置编码,特征图经过卷积降维和位置编码一起输入transformer中,最后经class_embed和bbox_embed做分类和bounding box预测。
2、定义数据集
从datasets/_init_.py文件的build_dataset函数可以看出,代码定义了两种数据读取方式,即“coco”和“coco_panoptic”,以常用的“coco”数据读取方式为例,进入到datasets/coco.py文件中,主要是CoCoDetection类负责读取:
再进入里面一层,实际是ConvertCoCoPolysToMask类负责读取,并且通过return_masks参数切换是否用于分割训练用的标注方式。
数据增强
在读取了图像和对应标注后,就做了数据增强,调用了make_coco_transforms函数对数据集做了水平翻转,随机resize,随机crop,随机scale等操作
3、训练
训练的重要代码都在engine.py文件中的train_one_epoch函数中,如下: