1. Source Code
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
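The subtlest part of forward is the attention mask during incremental decoding: with start_pos tokens already in the KV cache, the score matrix for the new block has shape (seqlen, start_pos + seqlen), so the cached columns must stay fully visible while the new block stays causal. The following is a minimal, self-contained sketch in plain PyTorch that reproduces that construction; the concrete sizes (start_pos=4, seqlen=3) are illustrative assumptions, not values from the source.

import torch

# Hypothetical decoding step: 4 tokens already cached, 3 new tokens arrive at once.
start_pos, seqlen = 4, 3

# Causal mask over the new block only: entry (i, j) is -inf whenever j > i.
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)

# Prepend start_pos zero columns for the cached positions: every new token
# may attend to the entire cache, so those entries are never masked.
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

print(mask)
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0., 0.]])

Row i of the result corresponds to absolute position start_pos + i, which is exactly why only entries with j > start_pos + i are masked.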
2. Architecture Diagram

(Figure: schematic of the forward pass, from the token embedding through the stacked TransformerBlocks to RMSNorm and the output projection.)
3. Annotated Code
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # Token embedding table, sharded across model-parallel ranks.
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        # Stack of n_layers identical Transformer blocks.
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        # Final RMSNorm and the linear head that maps back to vocabulary logits.
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        # Precompute the rotary-embedding table for head_dim = dim // n_heads.
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

    @torch.inference_mode()  # inference only; no gradients are tracked
    def forward(self, tokens: torch.Tensor, start_pos: int):
        # tokens: (batch, seqlen) token ids; start_pos: current KV-cache length.
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)  # (batch, seqlen, dim)

        # Rotary factors for the absolute positions covered by this block.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # Causal mask over the new block: entry (i, j) is -inf for j > i.
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)
            # Under KV caching the scores are (seqlen, start_pos + seqlen), so
            # prepend start_pos zero columns: the cache is always visible.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)                 # final normalization
        output = self.output(h).float()  # vocabulary logits
        return output
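The precompute_freqs_cis helper that builds the rotary table is called above but not shown in this post. For reference, here is a minimal standalone sketch consistent with the public LLaMA code; take it as an illustration of the idea rather than a verbatim copy of upstream.

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair rotation frequencies theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Absolute positions 0 .. end - 1.
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # Outer product: rotation angle for every (position, frequency) pair.
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex numbers e^(i * angle), shape (end, dim // 2).
    return torch.polar(torch.ones_like(freqs), freqs)

With the constructor arguments used above, this yields a complex tensor of shape (max_seq_len * 2, head_dim // 2), and forward slices rows start_pos : start_pos + seqlen so that each token is rotated according to its absolute position in the sequence.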