您的位置:首页 > 财经 > 产业 > PolyGen: An Autoregressive Generative Model of 3D Meshes代码polygen_decoder.py解读

PolyGen: An Autoregressive Generative Model of 3D Meshes代码polygen_decoder.py解读

2024/7/6 20:20:01 来源:https://blog.csdn.net/sinat_39783664/article/details/140160233  浏览:    关键词:PolyGen: An Autoregressive Generative Model of 3D Meshes代码polygen_decoder.py解读

论文:PolyGen: An Autoregressive Generative Model of 3D Meshes

首先阅读transformer铺垫知识《Torch中Transformer的中文注释》。

以下为Encoder部分,很简单,小学生都会:

from typing import Dict, Optional, Tuple
import pdbimport torch
import torch.nn as nn
from torch.nn import MultiheadAttention, Linear, Dropout, LayerNorm, ReLU, Parameter
import pytorch_lightning as plfrom .utils import get_clonesclass PolygenDecoderLayer(nn.TransformerDecoderLayer):"""根据Vaswani等人2017年的描述,这是一个解码器模块。它使用了遮蔽自注意力和非遮蔽跨注意力来处理序列化上下文模块。实现了缓存机制以加快解码速度。缓存的作用是存储键值对,这样在解码器的每次前向传递中就不必重新生成这些键值对。参数:d_model: 嵌入向量的大小。nhead: 多头注意力机制中的头数。dim_feedforward: 全连接层的大小。dropout: 在每个连接层后ReLU激活函数之后应用的dropout比率。re_zero: 如果为True,则使用alpha比例因子对残差进行零初始化。构造函数 __init__初始化参数包括d_model(嵌入向量的大小)、nhead(多头注意力机制中的头数)、dim_feedforward(全连接层的大小)、dropout(dropout比率)和re_zero(是否使用re_zero技术)。创建了多个多头注意力模块,包括自注意力(self.self_attn)和跨注意力(self.multihead_attn)。定义了前馈网络所需的线性层和激活函数。设置了层归一化(LayerNorm)和Dropout层。如果re_zero被设置为True,则初始化三个可学习的参数alpha、beta和gamma,用于调整残差连接中的权重,这有助于训练深度网络时的稳定性。"""def __init__(self,d_model: int = 256,nhead: int = 4,dim_feedforward: int = 1024,dropout: float = 0.2,re_zero: bool = True,) -> None:super(PolygenDecoderLayer, self).__init__(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout)# 初始化多头自注意力和多头跨注意力模块self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)# 初始化前馈网络self.linear1 = Linear(d_model, dim_feedforward)self.dropout = Dropout(dropout)self.linear2 = Linear(dim_feedforward, d_model)# 初始化层归一化self.norm1 = LayerNorm(d_model)self.norm2 = LayerNorm(d_model)self.norm3 = LayerNorm(d_model)# 初始化Dropoutself.dropout1 = Dropout(dropout)self.dropout2 = Dropout(dropout)self.dropout3 = Dropout(dropout)# 初始化激活函数self.activation = ReLU()# 初始化re_zero参数self.re_zero = re_zeroself.alpha = Parameter(data=torch.Tensor([0.0]))self.beta = Parameter(data=torch.Tensor([0.0]))self.gamma = Parameter(data=torch.Tensor([0.0]))def forward(self,tgt: torch.Tensor,memory: Optional[torch.Tensor] = None,tgt_mask: Optional[torch.Tensor] = None,memory_mask: Optional[torch.Tensor] = None,tgt_key_padding_mask: Optional[torch.Tensor] = None,memory_key_padding_mask: Optional[torch.Tensor] = None,cache: Optional[Dict[str, torch.Tensor]] = None,) -> torch.Tensor:"""解码器层的前向传播方法。参数:tgt: 输入序列的张量,形状为 [sequence_length, batch_size, embed_size]。(目标序列)memory: 编码器最后一层的序列张量,形状为 [source_sequence_length, batch_size, embed_size]。(来自编码器的输出)tgt_mask: 目标序列的掩码张量,形状为 [sequence_length, sequence_length]。memory_mask: 存储序列的掩码张量,形状为 [sequence_length, source_sequence_length]。tgt_key_padding_mask: 目标序列的键填充掩码张量,形状为 [batch_size, sequence_length]。memory_key_padding_mask: 存储序列的键填充掩码张量,形状为 [batch_size, source_sequence_length]。cache: 用于快速解码的缓存,格式为 {'k': torch.Tensor, 'v': torch.Tensor} 的字典。返回:输出张量,形状为 [sequence_length, batch_size, embed_size],是一个解码器层的前向循环结果。处理流程:接收tgt(目标序列)、memory(来自编码器的输出)、各种掩码和cache作为输入参数。如果提供了cache,会将当前的tgt追加到保存的键值对中,以便于后续的解码步骤能够利用历史信息。使用自注意力层处理输入tgt,考虑到了tgt_mask和tgt_key_padding_mask。如果re_zero开启,自注意力的输出乘以alpha,控制每层的贡献度。接着使用跨注意力层处理memory,同样考虑掩码,并使用beta调整输出。最后通过前馈网络,其中使用gamma调整前馈网络的输出。每个子层的输出都与输入进行了残差连接,并经过Dropout。NOTE:Cache technique.被用在Transformer模型的解码器层;cache中已经保存了之前解码步骤中的键和值张量。在解码新词时,这些已有的键和值张量会被与当前的tgt(目标序列)张量拼接在一起,形成更新后的键和值张量,然后将这些更新后的键和值张量再次存入cache中,供下一次解码使用"""# 处理缓存以实现快速解码if cache is not None:saved_key = cache["k"]saved_value = cache["v"]key = cache["k"] = torch.cat(tensors=[saved_key, tgt], axis=0)value = cache["v"] = torch.cat(tensors=[saved_value, tgt], axis=0)else:key = tgtvalue = tgt# 应用自注意力层tgt2 = self.norm1(tgt)tgt2 = self.self_attn(tgt, key, value, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]# 如果启用了re_zero,应用比例因子if self.re_zero:tgt2 = tgt2 * self.alpha  # 输出与一个比例因子(self.alpha)相乘,这有助于控制每层的贡献,从而稳定训练过程。tgt = tgt + self.dropout1(tgt2)# 应用跨注意力层if memory is not None:tgt2 = self.norm2(tgt)tgt2 = self.multihead_attn(tgt,memory.float(),memory.float(),attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask,)[0]if self.re_zero:tgt2 = tgt2 * self.betatgt2 = self.dropout2(tgt2)tgt = tgt + tgt2# 应用前馈网络tgt2 = self.norm3(tgt)tgt2 = self.linear1(tgt2)tgt2 = self.activation(tgt2)tgt2 = self.linear2(tgt2)if self.re_zero:tgt2 = tgt2 * self.gammatgt2 = self.dropout(tgt2)tgt = tgt + tgt2return tgtclass PolygenDecoder(pl.LightningModule):"""Polygen的解码器,修改自Pytorch的Transformer解码器实现,增加了缓存机制用于加速解码过程。"""def __init__(self,decoder_layer: nn.TransformerDecoderLayer,num_layers: int,norm: Optional[nn.Module] = None,) -> None:"""Polygen解码器的初始化函数。参数:decoder_layer: nn.TransformerDecoderLayer 类型的 Pytorch 模块,指定解码器层的具体实现。num_layers: 解码器中层的数量。norm: 在解码器的前向传播方法结束时应用的层归一化类型。初始化 (__init__ 方法):构造函数通过 get_clones 函数初始化 layers 属性,创建给定 decoder_layer 的多个副本,每个副本都是 nn.TransformerDecoderLayer 的独立实例。同时设置 num_layers 和 norm 属性。"""super(PolygenDecoder, self).__init__()self.layers = get_clones(decoder_layer, num_layers)  # 获取多个解码器层的副本self.num_layers = num_layers  # 解码器层数self.norm = norm  # 层归一化def forward(self,tgt: torch.Tensor,memory: Optional[torch.Tensor] = None,tgt_mask: Optional[torch.Tensor] = None,memory_mask: Optional[torch.Tensor] = None,tgt_key_padding_mask: Optional[torch.Tensor] = None,memory_key_padding_mask: Optional[torch.Tensor] = None,cache: Optional[Tuple] = None,) -> torch.Tensor:"""解码器层的前向传播方法。参数:tgt: 输入序列的张量,形状为 [sequence_length, batch_size, embed_size]。memory: 编码器最后一层的序列张量,形状为 [source_sequence_length, batch_size, embed_size]。tgt_mask: 目标序列的掩码张量,形状为 [sequence_length, sequence_length]。memory_mask: 存储序列的掩码张量,形状为 [sequence_length, source_sequence_length]。tgt_key_padding_mask: 忽略目标序列中指定填充元素的张量,形状为 [batch_size, sequence_length]。memory_key_padding_mask: 忽略存储序列中指定填充元素的张量,形状为 [batch_size, source_sequence_length]。cache: 用于快速解码的缓存,格式为每个层的 {'k': torch.Tensor, 'v': torch.Tensor} 字典组成的列表。返回值:输出张量,形状为 [sequence_length, batch_size, embed_size],是所有解码器层的前向传播结果。处理流程:遍历每层解码器,将当前的 output 张量应用到这一层上。如果提供了 cache,则获取当前层的缓存 (layer_cache) 并将 tgt_mask 设置为 None。这是因为缓存实际上通过存储过去的键值对来内部处理掩码,使得显式掩码变得不必要。若没有缓存,layer_cache 被设置为 None,表明不应使用任何缓存值。对于每一层,通过将 output 以及 memory 和相关掩码(以及可用的缓存)传入当前层,更新 output 张量。如果设置了 norm,在所有解码器层完成之后,对 output 应用层归一化。"""output = tgtfor i, mod in enumerate(self.layers):if cache is not None:layer_cache = cache[i]  # 使用缓存加速解码tgt_mask = None  # 使用缓存时,目标掩码置空else:layer_cache = None  # 不使用缓存output = mod(  # 将当前层的输入、编码器的输出、相关的掩码和缓存(如果有)作为参数,计算出新的输出张量。output,memory,tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask,cache=layer_cache,)if self.norm is not None:output = self.norm(output)  # 应用层归一化return outputclass TransformerDecoder(pl.LightningModule):def __init__(self,hidden_size: int = 256,fc_size: int = 1024,num_heads: int = 4,layer_norm: bool = True,num_layers: int = 8,dropout_rate: float = 0.2,) -> None:"""Transformer解码器,结合了PolygenDecoderLayer和PolygenDecoder。参数:hidden_size: 嵌入向量的大小。fc_size: 全连接层的大小。num_heads: 多头注意力头的数量。layer_norm: 是否使用层归一化。num_layers: 解码器中的层数。dropout_rate: ReLU之后立即应用的dropout比率。"""super(TransformerDecoder, self).__init__()# 初始化解码器结构self.hidden_size = hidden_sizeself.num_layers = num_layersself.decoder = PolygenDecoder(PolygenDecoderLayer(d_model=hidden_size,nhead=num_heads,dim_feedforward=fc_size,dropout=dropout_rate,),num_layers=num_layers,norm=LayerNorm(hidden_size),  # 使用层归一化)def initialize_cache(self, batch_size) -> Dict[str, torch.Tensor]:"""初始化用于快速解码的缓存。参数:batch_size: 输入批次的大小。返回:cache: 包含特定解码器层键值对的字典列表。"""# 初始化k和v张量k = torch.zeros([0, batch_size, self.hidden_size], device=self.device)v = torch.zeros([0, batch_size, self.hidden_size], device=self.device)cache = [{"k": k, "v": v} for _ in range(self.num_layers)]return cachedef generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:"""为输入序列生成一个目标掩码。为序列自注意力机制生成一个掩码,使得每个元素只能关注它之前的元素(包括自己),这样就防止了信息的未来泄露。生成一个下三角掩码矩阵,用于确保在序列自注意力机制中,每个位置只能看到它之前的元素(包括自己)参数:sz: 输入序列的长度。返回:mask: 形状为[sequence_length, sequence_length]的下三角矩阵。流程:假设我们有一个序列长度sz为5,即我们的序列长度是51. 下三角阵mask = (torch.triu(torch.ones(sz, sz, device=self.device)) == 1).transpose(0, 1)创建一个5x5的全1张量;然后将其传递给torch.triu函数,以获取一个上三角矩阵;转为下三角阵。2. 转换类型并应用掩码布尔型的下三角矩阵转为浮点型; 使用masked_fill函数将False的位置(即矩阵中非下三角的部分)填充为负无穷大(float('-inf')),而将True的位置(即下三角部分)填充为0tensor([[ 0., -inf, -inf, -inf, -inf],[ 0.,  0., -inf, -inf, -inf],[ 0.,  0.,  0., -inf, -inf],[ 0.,  0.,  0.,  0., -inf],[ 0.,  0.,  0.,  0.,  0.]])"""# 生成掩码,确保自注意力只关注不晚于当前位置的位置# 生成一个下三角掩码矩阵,用于确保序列中的每个位置只能注意到该位置之前(包含自身)的位置。mask = (torch.triu(torch.ones(sz, sz, device=self.device)) == 1).transpose(0, 1)mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))return maskdef forward(self,inputs: torch.Tensor,sequential_context_embeddings: Optional[torch.Tensor] = None,cache: Optional[tuple] = None,) -> torch.Tensor:"""Transformer解码器的前向传播方法。输入序列(inputs)、上下文嵌入(sequential_context_embeddings)和缓存(cache)参数:inputs: 形状为[sequence_length, batch_size, embed_size]的张量,代表输入序列。cache: 一个字典列表,每个字典格式为{'k': torch.Tensor, 'v': torch.Tensor},代表各解码器层的缓存。返回:out: 解码器所有层的前向传播结果,形状为[sequence_length, batch_size, embed_size]的张量。NOTE:tgt: A Tensor of shape [sequence_length, batch_size, embed_size]. Represents the input sequence.memory: A Tensor of shape [source_sequence_length, batch_size, embed_size]. Represents the sequence from the last layer of the encoder.tgt_mask: A Tensor of shape [sequence_length, sequence_length]. The mask for the target sequence.  下三角阵,控制模型不能开到`未来`的知识cache: A list of dictionaries in the following format: {'k': torch.Tensor, 'v': torch.Tensor}. Each dictionary in the list represents the cache at the respective decoder layer."""sz = inputs.shape[0]  # 输入序列长度mask = self.generate_square_subsequent_mask(sz)  # 生成自注意力掩码out = self.decoder(inputs, memory=sequential_context_embeddings, tgt_mask=mask, cache=cache)  # 执行前向传播return out

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com