from typing import Tuple

import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Helper function to reshape frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


######### Pad the tensor dimensions so the broadcast computation below works
def apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    Args:
        query (torch.Tensor): Query tensor to apply rotary embeddings. Shape: (batch_size, seqlen, n_local_heads, head_dim)
        key (torch.Tensor): Key tensor to apply rotary embeddings. Shape: (batch_size, seqlen, n_local_kv_heads, head_dim)
        head_dim (int): Dimension of each attention head.
        max_seq_len (int): Maximum sequence length supported by model.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    _, seqlen, _, _ = query.shape  # unpack the query tensor's shape
    device = query.device  # device the query tensor lives on (CPU or GPU)
    seq_len, batch_size, num_heads = query.size(1), query.size(0), query.size(2)  # sequence length, batch size, number of heads

    # Reshape xq and xk to match the complex representation:
    # group the last dimension into (head_dim // 2) pairs and unbind into real/imaginary parts.
    query_real, query_imag = query.float().reshape(query.shape[:-1] + (-1, 2)).unbind(-1)  # split query into real and imaginary parts
    key_real, key_imag = key.float().reshape(key.shape[:-1] + (-1, 2)).unbind(-1)  # split key into real and imaginary parts

    # Rotation angles: one frequency per channel pair, one angle per (position, frequency).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2.0, device=device) / head_dim))
    pos_seq = torch.arange(0, seqlen, device=device).float()
    sinusoid_inp = torch.einsum("i,j->ij", pos_seq, inv_freq)  # outer product, shape (seqlen, head_dim // 2)
    sin = torch.sin(sinusoid_inp)
    cos = torch.cos(sinusoid_inp)

    # Use the reshape_for_broadcast function to reshape cos and sin terms for broadcasting
    cos_rotations = reshape_for_broadcast(cos, query_real)  # reshape the cosine tensor for broadcasting
    sin_rotations = reshape_for_broadcast(sin, query_imag)  # reshape the sine tensor for broadcasting

    # Apply the rotations to the real and imaginary parts
    query_rot_real = cos_rotations * query_real - sin_rotations * query_imag  # rotate the real part of the query
    query_rot_imag = sin_rotations * query_real + cos_rotations * query_imag  # rotate the imaginary part of the query
    key_rot_real = cos_rotations * key_real - sin_rotations * key_imag  # rotate the real part of the key
    key_rot_imag = sin_rotations * key_real + cos_rotations * key_imag  # rotate the imaginary part of the key

    # Reassemble the real and imaginary parts back into the original format
    # query_out = torch.cat([query_rot_real, query_rot_imag], dim=-1).view_as(query)  # reassemble and reshape the query tensor
    # key_out = torch.cat([key_rot_real, key_rot_imag], dim=-1).view_as(key)  # reassemble and reshape the key tensor
    query_out = torch.stack((query_rot_real, query_rot_imag), dim=-1).flatten(-2)
    key_out = torch.stack((key_rot_real, key_rot_imag), dim=-1).flatten(-2)

    return query_out, key_out  # return the query and key tensors with rotary position embeddings applied
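For reference, here is a minimal usage sketch; the batch size, sequence length, head count, and head dimension are arbitrary example values, not taken from the original post:

# Minimal usage sketch -- shapes are arbitrary example values.
batch_size, seqlen, n_heads, head_dim = 2, 16, 4, 64

query = torch.randn(batch_size, seqlen, n_heads, head_dim)
key = torch.randn(batch_size, seqlen, n_heads, head_dim)

query_rot, key_rot = apply_rotary_emb(query, key, head_dim=head_dim, max_seq_len=seqlen)

print(query_rot.shape)  # torch.Size([2, 16, 4, 64]) -- same shape as the input
print(key_rot.shape)    # torch.Size([2, 16, 4, 64])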


The code above differs in two ways from the version explained in the Bilibili video, and from typical RoPE implementations (a comparison sketch follows this list):

1. The operation of element-wise multiplying the full vector q0, q1, q2, q3, … by new cos and sin tensors built by stacking two identical copies becomes: first split the tensor q into an even-indexed tensor q0, q2, q4, … and an odd-indexed tensor q1, q3, q5, …, then multiply each of them by the same, un-duplicated cos and sin tensors.

2. Padding the tensor dimensions out for broadcasting changes from

cos_cached = idx_theta2.cos()[:, None, None, :]
sin_cached = idx_theta2.sin()[:, None, None, :]

to

cos_rotations = reshape_for_broadcast(cos, query_real)  # reshape the cosine tensor for broadcasting
sin_rotations = reshape_for_broadcast(sin, query_imag)  # reshape the sine tensor for broadcasting
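For comparison, here is a rough sketch of the more conventional layout that points 1 and 2 refer to, reconstructed around the cos_cached / sin_cached lines quoted above. The function name conventional_rope, the (seqlen, batch_size, n_heads, head_dim) input layout, and the rotate-by-halves pairing are assumptions made for illustration, not the exact code from the video:

# Rough comparison sketch (assumed layout), not the exact code from the referenced video.
# x is assumed to have shape (seqlen, batch_size, n_heads, head_dim).
def conventional_rope(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    seqlen, _, _, head_dim = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2.0, device=x.device) / head_dim))
    pos = torch.arange(seqlen, device=x.device).float()
    idx_theta = torch.einsum("i,j->ij", pos, inv_freq)      # (seqlen, head_dim // 2)
    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)  # stack two identical copies -> (seqlen, head_dim)

    # Difference 2: broadcasting dimensions are padded with explicit None indexing
    cos_cached = idx_theta2.cos()[:, None, None, :]
    sin_cached = idx_theta2.sin()[:, None, None, :]

    # Difference 1: the full vector q0, q1, q2, ... is multiplied by the stacked cos/sin tensors;
    # here channels are paired half a head apart rather than split into even/odd indices.
    x_rotated = torch.cat([-x[..., head_dim // 2:], x[..., : head_dim // 2]], dim=-1)
    return x * cos_cached + x_rotated * sin_cached

Both layouts rotate two-dimensional channel pairs by the same set of position-dependent angles; they differ only in how the pairs are laid out along the head dimension and in how the cos/sin tensors are broadcast.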
