您的位置:首页 > 游戏 > 游戏 > 中信建设有限责任公司建筑院_郑州做网站设计_推广文案怎么写_天津百度网站排名优化

中信建设有限责任公司建筑院_郑州做网站设计_推广文案怎么写_天津百度网站排名优化

2025/4/22 20:43:36 来源:https://blog.csdn.net/qq_45812220/article/details/147347407  浏览:    关键词:中信建设有限责任公司建筑院_郑州做网站设计_推广文案怎么写_天津百度网站排名优化
中信建设有限责任公司建筑院_郑州做网站设计_推广文案怎么写_天津百度网站排名优化

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.scale = self.head_dim ** -0.5# 查询、键、值投影self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, embed_dim = x.shape# 投影q = self.q_proj(x)  # (batch, seq_len, embed_dim)k = self.k_proj(x)  # (batch, seq_len, head_dim)v = self.v_proj(x)  # (batch, seq_len, head_dim)# 重塑查询为多头q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# (batch, num_heads, seq_len, head_dim)# 键和值保持单头,扩展到多头维度k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale# (batch, num_heads, seq_len, seq_len)attn = F.softmax(scores, dim=-1)out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)# 合并多头out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)out = self.out_proj(out)  # (batch, seq_len, embed_dim)return out# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com