import torch
from torch import nn
from typing import Optional


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # When token_type_ids is not passed, fall back to the registered buffer from the constructor, which is
        # all zeros. The registered buffer lets users trace the model without passing token_type_ids and solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


# Configuration class
class BertConfig:
    def __init__(self):
        self.vocab_size = 30522             # vocabulary size of the base BERT model
        self.hidden_size = 768              # hidden dimension
        self.pad_token_id = 0               # ID of the padding token
        self.max_position_embeddings = 512  # maximum position embedding length
        self.type_vocab_size = 2            # number of token types (usually 2: sentence A and sentence B)
        self.layer_norm_eps = 1e-12         # epsilon for LayerNorm
        self.hidden_dropout_prob = 0.1      # dropout probability


# Create a configuration instance
config = BertConfig()

# Instantiate BertEmbeddings
embeddings = BertEmbeddings(config)

# Example 1: basic input (using input_ids)
input_ids = torch.tensor([
    [101, 2054, 2003, 102],  # [CLS] Hello world [SEP]
    [101, 2023, 4248, 102],  # [CLS] How are [SEP]
])  # shape (batch_size=2, seq_length=4)

# Forward pass
output = embeddings(
    input_ids=input_ids,
    token_type_ids=None,       # auto-generated as all zeros
    position_ids=None,         # taken from the position_ids buffer
    inputs_embeds=None,        # embeddings are looked up from input_ids
    past_key_values_length=0,  # no past tokens
)
print(f"Output shape: {output.shape}")  # should be torch.Size([2, 4, 768])
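# Aside: the non-persistent token_type_ids buffer registered in __init__ exists so the
# module can be traced without explicitly passing token_type_ids (the issue #5664 fix
# mentioned in forward()). A minimal sketch of that behavior, not part of the original
# snippet:
embeddings.eval()  # disable dropout so traced and eager outputs can be compared
traced = torch.jit.trace(embeddings, (input_ids,))
assert torch.allclose(traced(input_ids), embeddings(input_ids))
embeddings.train()  # restore the default training mode for the examples below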
# Example 2: using precomputed inputs_embeds
inputs_embeds = torch.rand(2, 4, config.hidden_size)  # randomly initialized embeddings
output = embeddings(
    input_ids=None,  # use inputs_embeds instead of input_ids
    inputs_embeds=inputs_embeds,
)
print(f"Output shape: {output.shape}")  # should be torch.Size([2, 4, 768])

# Example 3: custom token_type_ids (sentence-pair tasks)
token_type_ids = torch.tensor([
    [0, 0, 0, 1],  # first 3 tokens belong to sentence A, the last one to sentence B
    [0, 0, 1, 1],  # first 2 tokens belong to sentence A, the last 2 to sentence B
])
output = embeddings(
    input_ids=input_ids,
    token_type_ids=token_type_ids,
)

# Example 4: using past_key_values_length during generation
# Suppose 3 tokens have already been generated and the current input has length 1
output = embeddings(
    input_ids=torch.tensor([[2054]]),  # current token: "Hello"
    past_key_values_length=3,          # 3 tokens already generated
)
print(f"Output shape: {output.shape}")  # should be torch.Size([1, 1, 768])
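# Sanity check (a minimal sketch, not part of the original code): with dropout disabled,
# the module output should equal
# LayerNorm(word_embeddings + position_embeddings + token_type_embeddings).
embeddings.eval()  # disable dropout for a deterministic comparison
word = embeddings.word_embeddings(input_ids)
pos = embeddings.position_embeddings(embeddings.position_ids[:, : input_ids.size(1)])
tok = embeddings.token_type_embeddings(torch.zeros_like(input_ids))  # all-zero token types
manual = embeddings.LayerNorm(word + pos + tok)
assert torch.allclose(embeddings(input_ids), manual, atol=1e-6)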