Attention:Multi-head Attention
- 引言
- Multi-head Attention代码
引言
请注意!!!本博客使用了教程Transformers快速入门中的全部代码!!!
只在我个人理解的基础上为代码添加了注释!!!
详细教程请查看Transformers快速入门!!!
万分感谢!!!
自用!!!
Multi-head Attention代码
# Multi-head Attention 首先通过线性映射将 Q, K, V 序列映射到特征空间,
# 每一组线性投影后的向量表示称为一个头 (head),
# 然后在每组映射后的序列上再应用 Scaled Dot-product Attention:from torch import nn
import torch
import torch.nn.functional as F
from math import sqrtfrom transformers import AutoConfig
from transformers import AutoTokenizer# query_mask, key_mask:
#### 用于屏蔽某些查询或键的位置。如果指定,通常是 [batch_size, seq_length] 的张量。
#### 这些掩码可以屏蔽不需要参与计算的序列位置(例如,填充位置)。
# mask:
#### 更通用的掩码,形状通常为 [batch_size, seq_length, seq_length],
#### 用于屏蔽具体的查询-键对。
def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):dim_k = query.size(-1)scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)# 应用掩码(屏蔽不相关的输入)if query_mask is not None and key_mask is not None:# query_mask.unsqueeze(-1) 将查询掩码扩展为 [batch_size, seq_length, 1]。# key_mask.unsqueeze(1) 将键掩码扩展为 [batch_size, 1, seq_length]。# torch.bmm 计算两者的外积,得到形状为 [batch_size, seq_length, seq_length] 的掩码矩阵 mask。# 掩码矩阵中的值为 1 表示对应的查询-键对有效,0 表示无效。mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))if mask is not None:# 如果存在 mask,使用 masked_fill 方法将 mask == 0 的位置填充为 -inf,# 这些位置的得分被屏蔽。# Softmax 在处理 -inf 时会将其归一化为 0,从而屏蔽这些位置。scores = scores.masked_fill(mask == 0, -float("inf"))weights = F.softmax(scores, dim=-1)return torch.bmm(weights, value)# AttentionHead:
#### 每个注意力头会独立地从输入中学习查询(query)、键(key)和值(value)的表示,
#### 并通过注意力机制聚合上下文信息。
# nn.Module:
#### 继承自 PyTorch 的 nn.Module,是构建神经网络的基础类。
#### 提供了模块参数管理和自动求导的能力。
class AttentionHead(nn.Module):# embed_dim:#### 输入嵌入的维度。#### 表示输入序列中每个 token 的特征向量大小。# head_dim:#### 注意力头的维度。#### 每个注意力头会将输入嵌入维度 embed_dim 映射到更低的 head_dim 维度。# nn.Linear:#### 定义了三个线性变换层:######## self.q:将输入映射到查询(query)向量。######## self.k:将输入映射到键(key)向量。######## self.v:将输入映射到值(value)向量。#### 每个线性层的参数如下:######## 输入维度:embed_dim######## 输出维度:head_dim#### 通过这些线性变换,将原始输入嵌入的特征空间变换为注意力机制需要的特征空间。# =================================================================================================# 降维/升维(通过线性变换)def __init__(self, embed_dim, head_dim): # 768, 64super().__init__()self.q = nn.Linear(embed_dim, head_dim)self.k = nn.Linear(embed_dim, head_dim)self.v = nn.Linear(embed_dim, head_dim)# self.q(query):将查询向量通过线性层映射到头的特征空间,输出形状为 [batch_size, seq_length, head_dim]。# self.k(key) 和 self.v(value):同理,将键和值映射到特征空间。def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):attn_outputs = scaled_dot_product_attention(self.q(query), self.k(key), self.v(value), query_mask, key_mask, mask)return attn_outputs# 每个注意力头独立计算注意力分布,最终将所有头的结果连接起来并通过一个线性变换聚合。
class MultiHeadAttention(nn.Module):# config:#### 一个配置对象,通常用于存储模型的超参数。#### 必须包含以下字段:######## hidden_size:输入嵌入的总维度。######## num_attention_heads:注意力头的数量。def __init__(self, config):super().__init__()embed_dim = config.hidden_sizeprint("embed_dim的数量:")print(embed_dim)print("==================")num_heads = config.num_attention_headsprint("num_heads的数量:")print(num_heads)print("==================")# 实践中一般将 head_dim 设置为 embed_dim 的因数,# 这样 token 嵌入式表示的维度就可以保持不变,# 例如 BERT 有 12 个注意力头,因此每个头的维度被设置为 768 / 12 = 64# 确保总的注意力头输出维度与输入嵌入维度一致。head_dim = embed_dim // num_headsprint("head_dim的数量:")print(head_dim)print("==================")# 使用 nn.ModuleList 创建多个 AttentionHead 实例。# 每个 AttentionHead 独立计算注意力分布和上下文聚合。self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])# 定义一个线性层,用于将所有注意力头的结果进行聚合和变换。# 输入维度和输出维度都是 embed_dim,确保注意力模块的输入和输出形状一致。self.output_linear = nn.Linear(embed_dim, embed_dim)# query, key, value:#### 输入序列的查询、键和值向量。#### 形状为 [batch_size, seq_length, embed_dim]。# query_mask, key_mask, mask:#### 掩码,用于屏蔽不需要计算的查询或键,避免影响注意力计算。def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):# 遍历每个注意力头(self.heads),调用 AttentionHead 的 forward 方法。# 每个注意力头独立地处理查询、键和值,并返回其上下文输出。# 每个头的输出形状为 [batch_size, seq_length, head_dim]。# 将所有头的输出在最后一个维度(dim=-1)拼接起来。# 拼接后,形状为 [batch_size, seq_length, embed_dim],因为:# embed_dim = num_heads \times head_dimx = torch.cat([h.forward(query, key, value, query_mask, key_mask, mask) for h in self.heads], dim=-1)print("x的size:")print(x.size())print("==================")# 使用 self.output_linear 对拼接后的结果进行线性变换。# 线性变换可以引入头之间的交互信息,并生成最终的注意力输出。# 输出形状仍为 [batch_size, seq_length, embed_dim]。x = self.output_linear(x)return xmodel_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)text = "time flies like an arrow"
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
config = AutoConfig.from_pretrained(model_ckpt)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
inputs_embeds = token_emb(inputs.input_ids)multihead_attn = MultiHeadAttention(config)
query = key = value = inputs_embeds # 完全一样
print("query的size:")
print(query.size())
print("==================")
# attn_output = multihead_attn(query, key, value)
attn_output = multihead_attn.forward(query, key, value)
print("output的size:")
print(attn_output.size())
print("==================")
>>>
embed_dim的数量:
768
==================
num_heads的数量:
12
==================
head_dim的数量:
64
==================
query的size:
torch.Size([1, 5, 768])
==================
x的size:
torch.Size([1, 5, 768])
==================
output的size:
torch.Size([1, 5, 768])
==================