原文:DBNET文字检测 - 知乎 (zhihu.com)
一、DBNET介绍
DBNET核心采用的是基于分割的做法进行文本检测,即将每个文本块都进行语义分割,然后对分割概率图进行简单二值化、最终转化得为box或者poly格式的检测结果。除去网络设计方面的差异,最大特点是引入了Differentiable Binarization(DB)模块来优化分割预测结果。常规的基于语义分割的文本检测算法都是直接输出二值语义概率图或者其他辅助信息,然后经过阈值二值化得到最终结果,要想得到比较好的文本检测效果,一般都需要复杂的后处理,例如PSENet和PANet,会导致速度很慢。DBNET将阈值二值化过程变得可微,这一小小改动不仅可以增加错误预测梯度,也可以联合优化各个分支,得到更好的语义概率图。
二、DBNET算法流程
与常规基于语义分割算法的区别是多了一条threshold map分支,该分支的主要目的是和分割图联合得到更接近二值化的二值图,属于辅助分支。
2.1、backbone
骨架网络采用的是resnet18或者resnet50,为了增加网络特征提取能力,在layer2、layer3和layer4模块内部引入了变形卷积dcnv2模块。
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zooBatchNorm2d = nn.BatchNorm2d__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'deformable_resnet18', 'deformable_resnet50','resnet152']model_urls = {'resnet18': 'https://download.pytorch.org/modelss/resnet18-5c106cde.pth','resnet34': 'https://download.pytorch.org/modelss/resnet34-333f7ec4.pth','resnet50': 'https://download.pytorch.org/modelss/resnet50-19c8e357.pth','resnet101': 'https://download.pytorch.org/modelss/resnet101-5d3b4d8f.pth','resnet152': 'https://download.pytorch.org/modelss/resnet152-b121ed2d.pth',
}def constant_init(module, constant, bias=0): # 常量初始化nn.init.constant_(module.weight, constant)if hasattr(module, 'bias'):nn.init.constant_(module.bias, bias)def conv3x3(in_planes, out_planes, stride=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, bias=False)class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):super(BasicBlock, self).__init__()self.with_dcn = dcn is not Noneself.conv1 = conv3x3(inplanes, planes, stride) # 正常卷积过程self.bn1 = BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.with_modulated_dcn = Falseif not self.with_dcn:self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)else:from torchvision.ops import DeformConv2ddeformable_groups = dcn.get('deformable_groups', 1)offset_channels = 18self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1)self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False)self.bn2 = BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)# out = self.conv2(out)if not self.with_dcn:out = self.conv2(out)else:offset = self.conv2_offset(out)out = self.conv2(out, offset)out = self.bn2(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):super(Bottleneck, self).__init__()self.with_dcn = dcn is not Noneself.conv1 = nn.Con