import torchclass MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads):super(MaxState, self).__init__()assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."self.head_size = hidden_dim // headsself.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.hidden = hidden_dimdef forward(self, input_data, state=None):b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_sizeout = self.head0(input_data)out1 = self.head1(input_data)out2 = self.head2(input_data)out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])out = torch.cummax((out + out1) / h ** 0.5, 2)[0]out = out.permute([0, 2, 1, 3])out1 = out1.permute([0, 2, 1, 3])# out2 = out2.permute([0, 2, 1, 3])out = out.reshape([b, s, -1])out1 = out1.reshape([b, s, -1])# out2 = out2.reshape([b, s, -1])# out = self.layer_nor(out)# out = (out + out2) * out+out1# out3=torch.cummax(out,1)[0]out = (out + out2) * out + out1# out = self.alpha * out * (out + out2) + (1 - self.alpha) * out1return out, stateclass FeedForward(torch.nn.Module):def __init__(self, hidden_size):super(FeedForward, self).__init__()self.ffn1 = torch.nn.Linear(hidden_size, hidden_size // 2)self.ffn2 = torch.nn.Linear(hidden_size //2, hidden_size)self.gate = torch.nn.Linear(hidden_size, hidden_size // 2)self.relu = torch.nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))xx = x1 * x2x = self.ffn2(xx)return xclass DecoderLayer(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(DecoderLayer, self).__init__()self.self_attention = MaxState(hidden_size, num_heads)self.ffn = FeedForward(hidden_size)self.layer_norm = torch.nn.LayerNorm(hidden_size)self.alpha = torch.nn.Parameter(torch.tensor(0.5))def forward(self, x, state=None, ):x1, state = self.self_attention(x, state)x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)return x, stateclass SamOut(torch.nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super(SamOut, self).__init__()self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)# self.pos = torch.nn.Embedding(1024, hidden_size)self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])# self.head = torch.nn.Linear(hidden_size, voc_size, False)## self.down = torch.nn.ModuleList(# [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])def state_forward(self, state, x):if state is None:state = [None] * len(self.decoder_layers)i = 0for ii, decoder_layer in enumerate(self.decoder_layers):# x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))x1, state[i] = decoder_layer(x, state[i])x = x1 + xi += 1return x, state# def pos_forward(self, x):# if x.shape[1] >= 1024:# pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)# pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos## else:# pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)# return posdef forward(self, x, state=None):x = self.em(x)# pos = self.pos_forward(x)x, state = self.state_forward(state, x)# return self.head(x), statereturn x@self.em.weight.permute([1,0]), statedevice = "cuda"
if __name__ == '__main__':net = SamOut(235, 256, 16, 4)net.to(device)net(torch.randint(0, 200, [2, 8 * 13]).to(device))#
该代码定义了一个基于PyTorch的神经网络模型,其核心是一个自定义的解码器层(DecoderLayer
),它使用了多头机制(MaxState
)和前馈网络(FeedForward
)。整个模型(SamOut
)用于处理序列数据,比如自然语言处理中的文本。让我们逐个解析这些组件,并讨论输入输出层参数共享的优势。
组件解析
-
MaxState:
- 定义了一个带有多个注意力头的模块,每个头都有自己的线性变换(
head0
,head1
,head2
)。 - 使用了累积最大值函数(
cummax
)来聚合信息。 - 输出维度被重新调整为批量大小、序列长度、头数和头尺寸的形式。
- 定义了一个带有多个注意力头的模块,每个头都有自己的线性变换(
-
FeedForward:
- 一个标准的前馈网络,包括两个线性层和一个门控机制(通过ReLU激活函数实现),用于在解码器层内部处理信息。
-
DecoderLayer:
- 包含一个
MaxState
实例和一个FeedForward
实例。 - 应用了层归一化(
LayerNorm
)以稳定训练过程。 - 使用可学习参数
alpha
来控制来自前馈网络和输入的加权和。
- 包含一个
-
SamOut:
- 整合了所有上述组件,形成了完整的模型。
- 使用嵌入层(
Embedding
)将词汇表索引转换为密集向量表示。 - 模型包含多个解码器层,由
ModuleList
管理。 - 最终输出通过与嵌入层权重的矩阵乘法得到,而不是使用单独的线性层作为输出层。
输入输出层参数共享的优势
在SamOut
类中,输出层并没有显式定义为一个线性层,而是直接通过嵌入层的权重转置进行计算(x @ self.em.weight.permute([1, 0])
)。这种做法通常被称为“权重绑定”或“参数共享”,它具有以下优势:
- 减少参数数量:由于不引入新的权重,模型的总参数量减少,这有助于降低过拟合的风险。
- 加速训练:较少的参数意味着更少的计算资源需求,可以加快训练速度。
- 一致性:输入层和输出层共享相同的权重,保证了模型在处理输入时学到的特征和生成输出时所依赖的特征之间的一致性。
- 简化架构:无需额外定义输出层,简化了模型架构。
这种技术特别适用于词汇预测任务,如语言模型或机器翻译,其中输入和输出都在同一个词汇空间中。通过共享嵌入层和输出层的权重,我们可以使模型更加紧凑和高效,同时保持良好的性能。