Attention with Bilinear Correlation for Infrared Small Target Detection
论文地址:
问题:
解决方案:
注意力机制的适用性:
即插即用代码:
论文地址:
https://arxiv.org/pdf/2303.10321https://arxiv.org/pdf/2303.10321
问题:
红外小目标检测 (ISTD) 面临着以下挑战:
-
目标特征弱: 红外小目标由于成像距离远、红外辐射能量衰减等因素,缺乏清晰的轮廓和纹理特征,容易被背景噪声和杂波淹没,导致目标信息提取困难。
-
深度学习方法的局限性:
-
基于 CNN 的方法: CNN 擅长提取局部特征,但受限于感受野大小,难以捕捉全局信息,容易受到噪声干扰,导致目标误检和漏检。
-
基于 Transformer 的方法: Transformer 具有强大的全局建模能力,但由于缺乏卷积操作的归纳偏置,难以有效提取局部特征,且在多次下采样后容易丢失目标信息,导致目标检测性能下降。
-
解决方案:
论文提出了一个名为 ABC 的新型模型,该模型结合了 CNN 和 Transformer 的优势,有效地解决了上述问题:
-
卷积线性融合 Transformer (CLFT) 模块:
-
Transformer 结构: 该模块基于 Transformer 结构,利用其全局建模能力,可以有效地捕捉图像中的长距离依赖关系,从而提取全局特征。
-
重新设计的自注意力机制: 该模块重新设计了自注意力机制,通过引入双线性注意力模块 (BAM) 计算注意力矩阵,可以更有效地关注目标区域,抑制背景噪声和杂波。
-
卷积和扩张卷积: 该模块结合了卷积和扩张卷积,可以有效地提取局部特征和全局特征,从而更全面地描述目标信息。
-
特征融合: 该模块将注意力机制提取的全局特征与卷积操作提取的局部特征进行融合,从而增强目标特征并抑制噪声。
-
-
U 型卷积-扩张卷积 (UCDC) 模块:
-
U 型结构: 该模块采用 U 型结构,通过跳跃连接将编码器中不同层次的特征与解码器中的特征进行融合,从而实现多尺度特征融合,提高目标检测的准确性。
-
卷积和扩张卷积: 该模块结合了卷积和扩张卷积,可以有效地提取不同尺度的特征,从而更精细地描述目标信息。
-
特征细化: 该模块位于网络的更深层次,利用深度特征分辨率较小的特点,可以提取更细粒度的语义信息,进一步细化目标特征,提高目标检测的精度。
-
注意力机制的适用性:
-
语义分割: ABC 模型中的注意力机制可以有效地关注目标区域,抑制背景噪声和杂波,从而提高语义分割的准确性。例如,在医学图像分割、遥感图像分割等任务中,注意力机制可以帮助网络更好地识别目标区域,从而提高分割结果的准确性。
-
目标检测: 注意力机制可以用于目标检测的各个阶段,例如:
-
特征提取阶段: 注意力机制可以引导网络关注目标区域,提取更有效的特征,从而提高目标检测的准确性。例如,在基于 CNN 的目标检测模型中,注意力机制可以与特征金字塔网络 (FPN) 结合,从而更有效地提取多尺度特征。
-
目标定位阶段: 注意力机制可以聚焦于目标中心区域,提高目标定位的精度。例如,在基于回归的目标检测模型中,注意力机制可以帮助网络更准确地预测目标的位置。
-
目标分类阶段: 注意力机制可以关注目标的关键区域,从而提高目标分类的准确率。例如,在基于 R-CNN 的目标检测模型中,注意力机制可以帮助网络更准确地识别目标类别。
-
ABC 模型通过结合 CNN 和 Transformer 的优势,有效地解决了红外小目标检测中的特征损失和噪声干扰问题,取得了优异的性能。其注意力机制也可以应用于其他视觉任务,例如语义分割和目标检测,提高模型的性能。
即插即用代码:
import torch
import torch.nn as nn
from einops import rearrange
def conv_relu_bn(in_channel, out_channel, dirate):return nn.Sequential(nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=dirate,dilation=dirate),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True))#bilinear attention module (BAM)
class BAM(nn.Module):def __init__(self, in_dim, in_feature, out_feature):super(BAM, self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)self.query_line = nn.Linear(in_features=in_feature, out_features=out_feature)self.key_line = nn.Linear(in_features=in_feature, out_features=out_feature)self.s_conv = nn.Conv2d(in_channels=1, out_channels=in_dim, kernel_size=1)self.softmax = nn.Softmax(dim=-1)def forward(self, x):q = rearrange(self.query_line(rearrange(self.query_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b h 1')k = rearrange(self.key_line(rearrange(self.key_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b 1 h')att = rearrange(torch.matmul(q, k), 'b h w -> b 1 h w')att = self.softmax(self.s_conv(att))return attclass Conv(nn.Module):def __init__(self, in_dim):super(Conv, self).__init__()self.convs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, 1) for _ in range(3)])def forward(self, x):for conv in self.convs:x = conv(x)return x#dilated convolution layers(DConv)
class DConv(nn.Module):def __init__(self, in_dim):super(DConv, self).__init__()dilation = [2, 4, 2]self.dconvs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, dirate) for dirate in dilation])def forward(self, x):for dconv in self.dconvs:x = dconv(x)return xclass ConvAttention(nn.Module):def __init__(self, in_dim, in_feature, out_feature):super(ConvAttention, self).__init__()self.conv = Conv(in_dim)self.dconv = DConv(in_dim)self.att = BAM(in_dim, in_feature, out_feature)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):q = self.conv(x)k = self.dconv(x)v = q + katt = self.att(x)out = torch.matmul(att, v)return self.gamma * out + v + xclass FeedForward(nn.Module):def __init__(self, in_dim, out_dim):super(FeedForward, self).__init__()self.conv = conv_relu_bn(in_dim, out_dim, 1)# self.x_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)self.x_conv = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1),nn.BatchNorm2d(out_dim),nn.ReLU(inplace=True))def forward(self, x):out = self.conv(x)x = self.x_conv(x)return x + out#convolution linear fusion transformer (CLFT)
class CLFT(nn.Module):def __init__(self, in_dim, out_dim, in_feature, out_feature):super(CLFT, self).__init__()self.attention = ConvAttention(in_dim, in_feature, out_feature)self.feedforward = FeedForward(in_dim, out_dim)def forward(self, x):x = self.attention(x)out = self.feedforward(x)return outif __name__ == '__main__':block = CLFT(64,64,32*32,32) # 输入通道数,输出通道数 图像大小 H*W,H or Winput = torch.randn(3, 64, 32, 32) #输入tensor形状 B C H W# Print input shapeprint(input.size()) # 输入形状# Pass the input tensor through the modeloutput = block(input)# Print output shapeprint(output.size()) # 输出形状
大家对于YOLO改进感兴趣的可以进群了解,群中有答疑,(QQ群:828370883)