
Code Implementations of Attention and Self-Attention Modules (Stitching / Plug-and-Play Modules)

2024/10/6 18:22:30  Source: https://blog.csdn.net/GDHBFTGGG/article/details/140698517
Attention Mechanism
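The module below implements a simple attention pooling layer: a single linear layer (without bias) scores each time step of the encoder outputs, the scores are normalized with a softmax over the sequence dimension, and the weighted sum of the encoder outputs is returned as a fixed-size context vector together with the attention weights.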

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs):
        # encoder_outputs shape: (batch_size, sequence_length, hidden_dim)
        attn_weights = self.attention(encoder_outputs)  # (batch_size, sequence_length, 1)
        attn_weights = torch.softmax(attn_weights, dim=1)  # (batch_size, sequence_length, 1)
        context = torch.sum(attn_weights * encoder_outputs, dim=1)  # (batch_size, hidden_dim)
        return context, attn_weights

# Example usage
batch_size = 2
sequence_length = 5
hidden_dim = 10

encoder_outputs = torch.randn(batch_size, sequence_length, hidden_dim)
attention_layer = Attention(hidden_dim)
context, attn_weights = attention_layer(encoder_outputs)

print("Context:", context)
print("Attention Weights:", attn_weights)

Self-Attention Mechanism
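The module below implements standard multi-head scaled dot-product self-attention: a single linear layer projects the input to queries, keys, and values for all heads at once, each head attends over the sequence with softmax(QK^T / sqrt(head_dim)), and the concatenated head outputs are passed through a final output projection.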

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)  # (batch_size, seq_length, embed_dim * 3)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, 3 * head_dim)

        q, k, v = qkv.chunk(3, dim=-1)  # Each has shape (batch_size, num_heads, seq_length, head_dim)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # Scaled dot-product
        attn_weights = self.softmax(attn_weights)  # (batch_size, num_heads, seq_length, seq_length)

        attn_output = torch.matmul(attn_weights, v)  # (batch_size, num_heads, seq_length, head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim)

        output = self.o_proj(attn_output)
        return output, attn_weights

# Example usage
batch_size = 2
seq_length = 5
embed_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_length, embed_dim)
self_attention_layer = MultiHeadSelfAttention(embed_dim, num_heads)
output, attn_weights = self_attention_layer(x)

print("Output:", output)
print("Attention Weights:", attn_weights)
 
