您的位置:首页 > 新闻 > 资讯 > 宁波seo外包方案_成都哪里好玩适合小孩_西安seo建站_西安seo外包行者seo06

宁波seo外包方案_成都哪里好玩适合小孩_西安seo建站_西安seo外包行者seo06

2024/12/23 22:09:12 来源:https://blog.csdn.net/2401_86807530/article/details/144515826  浏览:    关键词:宁波seo外包方案_成都哪里好玩适合小孩_西安seo建站_西安seo外包行者seo06
宁波seo外包方案_成都哪里好玩适合小孩_西安seo建站_西安seo外包行者seo06

# 混合注意力机制(Hybrid Attention Mechanism)是一种结合空间和通道注意力的策略,旨在提高神经网络的特征提取能力。

# 空间和通道都加上去

# CBAM是一种轻量级的注意力模块,它通过增加空间和通道两个维度的注意力,来提高模型的性能。

# 在某个阶段 先后加入通道和空间

import torch

import torch.nn as nn

# CBAM 混合注意力 方法 的实现

# 通道注意力构建

class ChannelAtt(nn.Module):

    def __init__(self,c,r= 16,*args, **kwargs):

        super().__init__(*args, **kwargs)

        self.max=nn.Sequential(nn.AdaptiveMaxPool2d(1),nn.ReLU())

        self.avg=nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.ReLU())

        # 感知机 两个池化完成之后 得到通道系数的过程 这两个结果是共用的

        self.perceptron=nn.Sequential(nn.Linear(c,c//r),nn.ReLU(),nn.Linear(c//r,c))

        self.activate=nn.Sigmoid()

    def forward(self,x):

        x1=self.max(x)

        x2=self.avg(x)

        x1=self.perceptron(x1.view(x1.shape[0],-1))

        x2=self.perceptron(x2.view(x2.shape[0],-1))

        att=self.activate(x1+x2)

        att=att.unsqueeze(2).unsqueeze(3)

        return x*att

# 空间注意力构建

class SpaceAtt(nn.Module):

    def __init__(self,kernel_size=7, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.fc=nn.Sequential(nn.Conv2d(in_channels=2,out_channels=1,kernel_size=kernel_size,padding=kernel_size//2),nn.Sigmoid())

       

    def forward(self,x):

        x1=torch.max(x,dim=1,keepdim=True)# 当 keepdim=True 时,计算均值后,输出张量将保持被约简的维度,但该维度的大小将为 1。也就是说,结果张量会保留原始张量的形状结构,只是被约简的维度变为 1。

        x2=torch.mean(x,dim=1,keepdim=True)

        att=self.fc(torch.cat((x1.values,x2),dim=1))

        return x*att

class CBAM(nn.Module):

    def __init__(self, c,r=16,*args, **kwargs):

        super().__init__(*args, **kwargs)

        self.channel_att=ChannelAtt(c,r)

        self.space_att=SpaceAtt()

    def forward(self,x):

        x=self.channel_att(x)

        x=self.space_att(x)

        return x

img=torch.rand(1,128,224,224)

cbam=CBAM(img.shape[1])

res=cbam(img)

print(res.shape)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com