您的位置:首页 > 娱乐 > 明星 > 艺术设计网_b2b商务贸易平台_成都seo优化公司排名_搜索网站有哪些

艺术设计网_b2b商务贸易平台_成都seo优化公司排名_搜索网站有哪些

2025/4/5 14:08:31 来源:https://blog.csdn.net/m0_72851153/article/details/146459619  浏览:    关键词:艺术设计网_b2b商务贸易平台_成都seo优化公司排名_搜索网站有哪些
艺术设计网_b2b商务贸易平台_成都seo优化公司排名_搜索网站有哪些

一、源码展示

class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.params = paramsself.vocab_size = params.vocab_sizeself.n_layers = params.n_layersself.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)self.layers = torch.nn.ModuleList()for layer_id in range(params.n_layers):self.layers.append(TransformerBlock(layer_id, params))self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads,params.max_seq_len * 2,params.rope_theta,)@torch.inference_mode()def forward(self, tokens: torch.Tensor, start_pos: int):_bsz, seqlen = tokens.shapeh = self.tok_embeddings(tokens)self.freqs_cis = self.freqs_cis.to(h.device)freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]mask = Noneif seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)mask = torch.triu(mask, diagonal=1)# When performing key-value caching, we compute the attention scores# only for the new sequence. Thus, the matrix of scores is of size# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for# j > cache_len + i, since row i corresponds to token cache_len + i.mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)h = self.norm(h)output = self.output(h).float()return output

二、原理图

在这里插入图片描述

三、代码注释

class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()# 基本参数self.params = params# 词汇表大小self.vocab_size = params.vocab_size# 模型的层数self.n_layers = params.n_layers# 这个嵌入层会把每个单词映射到一个高维向量,这个高维向量就是这个单词的嵌入。self.tok_embeddings = ParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)# 创建了一个空的模块列表self.layers = torch.nn.ModuleList()# 添加了n_layers个TransformerBlock到列表中for layer_id in range(params.n_layers):self.layers.append(TransformerBlock(layer_id, params))# 创建了一个RMSNorm层,它用于对输入数据进行归一化处理。self.norm = RMSNorm(params.dim, eps=params.norm_eps)# ColumnParallelLinear层是一个线性层,用于将输入数据的特征从params.dim维映射到params.vocab_size维。# 这种映射是通过学习一组权重来实现的,权重矩阵的大小为 params.dim x params.vocab_size。# 简言之,将输入转化为params.vocab_size维的输出,这个输出可以看作是预测每个词汇的概率分布。self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)# 计算了freqs_cis,这是一个预计算的张量,用于后面的旋转位置嵌入(Rotary Position Embedding)self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_headers, self.params.max_seq_len * 2,)# 通过torch.inference_mode()装饰器来指示这个方法将用于模型推理,# 这可以帮助PyTorch优化计算,并在可能的情况下减少内存使用。@torch.inference_mode()def forward(self, tokens: torch.Tensor, start_pos: int):# 批量大小(_bsz)和序列长度(seqlen)_bsz, seqlen = tokens.shape# 词嵌入向量h = self.tok_embeddings(tokens)# 根据输入的序列起始位置start_pos和序列长度seqlen,从self.freqs_cis中取出对应的旋转嵌入。# 这些旋转嵌入将用于后续的Transformer层中,对输入的词嵌入进行旋转操作,以编码位置信息。freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]mask = Noneif seqlen > 1:# 模型首先生成了一个掩码(mask),这个掩码被用于transformer层以防止在自注意力机制中考虑到未来的词汇。mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)# 这是通过填充一个全为负无穷的矩阵,然后使用torch.triu(取上三角)函数,来创建一个遮罩,# 该遮罩对应的位置上的元素,# 如果它们代表的词在序列中是在当前词之后的词,则值为负无穷,否则为0。mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)# 对每个transformer层,依次将当前的嵌入向量(或者前一层的输出)作为输入,# 执行该层的前向传播,计算结果将用于下一层的输入。for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 将最后一层transformer层的输出通过一个规范化(norm)层,然后通过一个全连接层(self.output),# 转换为最后的模型输出。这个输出的尺寸应该与词汇表的大小相同,因此每个词都有一个对应的分数,# 这个分数代表模型认为该词是下一个词的可能性。h = self.norm(h)output = self.output(h).float()return output

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com