import torch


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # self.h_linear = torch.nn.Parameter(torch.empty(1, 1))
        # torch.nn.init.xavier_uniform_(self.h_linear, 0.5)
        # self.layer_nor = torch.nn.LayerNorm(hidden_dim)
        # self.norm = torch.nn.LayerNorm(hidden_dim)
        # self.alpha = torch.nn.Parameter(torch.tensor(0.5))
        self.head_num = heads
        self.hidden = hidden_dim
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        # self.head.to(device)
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size
        out = self.head0(input_data)
        out1 = self.head1(input_data)
        out2 = self.head2(input_data)
        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        # out2 = out2.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        # out1 = self.head1(input_data).reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out = torch.cummax((out + out1) / h ** 0.5, 2)[0]
        # out = torch.cummin((out + out1) / k ** 0.5, 2)[0]
        # out_sum = torch.cumsum((out + out1) / k ** 0.5, 2)
        # out = (out - out_min) * out
        out = out.permute([0, 2, 1, 3])
        out1 = out1.permute([0, 2, 1, 3])
        # out2 = out2.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])
        out1 = out1.reshape([b, s, -1])
        # out2 = out2.reshape([b, s, -1])
        # out = self.layer_nor(out)
        # out = (out + out2) * out + out1
        # out3 = torch.cummax(out, 1)[0]
        # out = (out + out2) * out + out1
        out = self.layer_nor(out + out2 + out1)
        # out = self.alpha * out * (out + out2) + (1 - self.alpha) * out1
        return out


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        # self.h_linear = torch.nn.Parameter(torch.empty(1, 1))
        # self.gate = torch.nn.Parameter(torch.empty(hidden_size, hidden_size * 2))
        # torch.nn.init.xavier_uniform_(self.gate, 0.5)
        self.relu = torch.nn.ReLU()
        self.dr = torch.nn.Dropout(0.1)

    def forward(self, x):
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        xx = self.dr(x1 * x2)
        x = self.ffn2(xx)
        return x


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()

        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.state = MaxState(hidden_size, num_heads)
        self.state1 = MaxState(hidden_size, num_heads)
        self.state2 = MaxState(hidden_size, num_heads)
        self.state3 = MaxState(hidden_size, num_heads)
        self.state4 = MaxState(hidden_size, num_heads)
        self.state5 = MaxState(hidden_size, num_heads)

        self.decoder = FeedForward(hidden_size)
        self.decoder1 = FeedForward(hidden_size)
        self.decoder2 = FeedForward(hidden_size)
        self.decoder3 = FeedForward(hidden_size)
        self.decoder4 = FeedForward(hidden_size)
        self.decoder5 = FeedForward(hidden_size)

        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        self.layer_nor = torch.nn.LayerNorm(hidden_size)

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x):
        x = self.em(x)
        pos = self.pos_forward(x)

        x = self.state(x + pos) + x
        x1 = self.decoder(x)
        x2 = self.state1(x1) + x1
        x2 = self.decoder1(x2)
        x3 = self.state2(x1) + x1
        x3 = self.decoder2(x3)
        x = self.layer_nor(x1 + x2 + x3)

        x = self.state3(x) + x
        x1 = self.decoder3(x)
        x2 = self.state4(x1) + x1
        x2 = self.decoder4(x2)
        x3 = self.state5(x1) + x1
        x3 = self.decoder5(x3)
        x = self.layer_nor(x1 + x2 + x3)

        return self.head(x), ""


device = "cuda"

if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
# epoch___0____loss___8.586270____steps___65760: 0%| | 0/1 [01:21<?, ?it/s] cummax
# epoch___0____loss___6.930531____steps___67040: 0%| | 0/1 [01:21<?, ?it/s] cummax no layer_nor
# epoch___0____loss___7.680687____steps___77840: 0%| | 0/1 [01:35<?, ?it/s] cummax layer_nor
# epoch___0____loss___6.994579____steps___68240: 0%| | 0/1 [01:25<?, ?it/s] cummax cos
# epoch___0____loss___6.707716____steps___70640: 0%| | 0/1 [01:24<?, ?it/s] cummax no sin no cos
# epoch___0____loss___6.895388____steps___65200: 0%| | 0/1 [01:21<?, ?it/s] cummin
# epoch___0____loss___7.079460____steps___66720: 0%| | 0/1 [01:22<?, ?it/s] cummax no x
# epoch___0____loss___6.174834____steps___45360: 0%| | 0/10 [01:00<?, ?it/s] cummax 2 2 no pos
# epoch___0____loss___6.239753____steps___45120: 0%| | 0/10 [01:00<?, ?it/s] cummax 2 2 pos
# epoch___0____loss___6.547979____steps___36240: 0%| | 0/10 [01:00<?, ?it/s] cummax 3 3 no pos
# epoch___0____loss___6.947957____steps___17600: 0%| | 0/10 [01:01<?, ?it/s] src samout
# epoch___0____loss___6.108305____steps___52640: 0%| | 0/10 [02:54<?, ?it/s] src samout
# epoch___0____loss___6.069768____steps___55280: 0%| | 0/10 [03:03<?, ?it/s] src samout
# epoch___0____loss___6.058203____steps___54560: 0%| | 0/10 [01:11<?, ?it/s] current samout
# epoch___0____loss___5.996508____steps___52560: 0%| | 0/10 [01:27<?, ?it/s]
# epoch___0____loss___6.067177____steps___54400: 0%| | 0/10 [01:30<?, ?it/s]
# epoch___0____loss___5.974577____steps___52720: 0%| | 0/10 [01:44<?, ?it/s]
# epoch___0____loss___5.869751____steps___55520: 0%| | 0/10 [01:57<?, ?it/s]
# epoch___0____loss___5.749324____steps___55440: 0%| | 0/10 [02:03<?, ?it/s] maxstate no cat
# epoch___0____loss___5.715099____steps___55440: 0%| | 0/10 [02:26<?, ?it/s] cat
# epoch___0____loss___5.704436____steps___55520: 0%| | 0/10 [02:04<?, ?it/s] x1 +x2+x3
# epoch___0____loss___5.710885____steps___55360: 0%| | 0/10 [02:04<?, ?it/s] x1 + x2 + x3 beats cat and uses fewer parameters
# epoch___0____loss___5.673217____steps___55360: 0%| | 0/10 [02:00<?, ?it/s] out+out1+out2
# epoch___0____loss___5.669157____steps___55360: 0%| | 0/10 [02:13<?, ?it/s]
# epoch___0____loss___5.677723____steps___55360: 0%| | 0/10 [02:42<?, ?it/s]
# epoch___0____loss___5.494996____steps___55360: 0%| | 0/10 [03:43<?, ?it/s]
# epoch___0____loss___5.319009____steps___55280: 0%| | 0/10 [03:42<?, ?it/s] 0.0003
# epoch___0____loss___4.823767____steps___54160: 0%| | 0/10 [03:38<?, ?it/s] 0.0003 layer norm at the end
# epoch___0____loss___4.830925____steps___54240: 0%| | 0/10 [03:39<?, ?it/s] 0.0003 layer norm everywhere
# epoch___0____loss___4.843996____steps___56160: 0%| | 0/10 [03:46<?, ?it/s] 0.0003 relu in the middle
# epoch___0____loss___4.821821____steps___55520: 0%| | 0/10 [03:44<?, ?it/s] 0.0003 gelu in the middle
# epoch___0____loss___5.115138____steps___60400: 0%| | 0/10 [04:03<?, ?it/s] 0.0003 layer norm in the middle
This LLM design is a PyTorch-based sequence model made up of several custom neural-network modules. An overview of its main components follows:
- MaxState module:
  - This module acts as a custom attention-style mechanism: the input is passed through several linear projections, and a cumulative maximum (cummax) along the sequence dimension then lets every position aggregate information from all positions up to and including itself (see the short cummax sketch below).
  - The projections are split into multiple "heads" (head_size = hidden_dim // heads), echoing the multi-head attention of the Transformer.
  - The module also applies a LayerNorm to normalize its output.
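To make the role of cummax concrete, here is a minimal sketch (not from the original listing) of what torch.cummax does along the sequence axis; in MaxState it is applied to (out + out1) / head_size ** 0.5 with dim=2 after the head permutation:

import torch

# Toy case: batch of one, sequence length 4, scalar features.
# torch.cummax returns, at every position, the element-wise maximum over
# that position and all earlier ones, so each step only sees its prefix.
scores = torch.tensor([[0.2, 0.7, 0.1, 0.9]])
running_max, _ = torch.cummax(scores, dim=1)
print(running_max)  # tensor([[0.2000, 0.7000, 0.7000, 0.9000]])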
- FeedForward module:
  - A feed-forward network with two linear layers, a ReLU activation, and dropout.
  - It adds a gating mechanism: a third linear layer produces a gate signal that, after the ReLU, is multiplied element-wise with the output of the first linear layer (spelled out in the sketch below).
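For clarity, the gating can be written out as plain tensor ops. This is only an illustrative rewrite of FeedForward.forward with the dropout omitted; the function name is made up for the example:

import torch

def gated_ffn_step(x, ffn1, ffn2, gate):
    # Mirrors FeedForward.forward (minus dropout): the ReLU-activated gate
    # branch scales the value branch element-wise before projecting back down.
    value = ffn1(x)           # [batch, seq, 2 * hidden]
    g = torch.relu(gate(x))   # [batch, seq, 2 * hidden]
    return ffn2(value * g)    # [batch, seq, hidden]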
- SamOut model:
  - The full model, built from several MaxState and FeedForward modules.
  - Input tokens are first processed by an embedding layer and a learned positional-embedding layer.
  - MaxState and FeedForward blocks are then applied repeatedly, in alternation and with residual connections, to process the sequence.
  - Finally, a linear head projects the hidden states to vocabulary-sized logits (see the inference sketch below).
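As a usage illustration, the head's logits can be turned into a greedy next-token prediction. This sketch is not part of the original listing; it assumes it runs in the same module as the classes above, and it overrides the module-level device name (which pos_forward reads) to "cpu" so the demo runs without a GPU:

import torch

device = "cpu"  # override the module-level name used inside pos_forward

net = SamOut(voc_size=235, hidden_size=256, num_heads=16, num_layers=4)
net.to(device)
net.eval()

tokens = torch.randint(0, 200, [1, 16]).to(device)  # [batch=1, seq=16] token ids
with torch.no_grad():
    logits, _ = net(tokens)                         # logits: [1, 16, 235]
next_token = logits[:, -1].argmax(-1)               # greedy pick for the next token
print(logits.shape, next_token)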
- Training and inference:
  - The listing above only instantiates the model in its __main__ block and runs a single forward pass as a smoke test; the loop that produced the loss figures in the comments (computing the loss and updating the weights) is not shown. A minimal sketch of such a loop follows.
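The sketch below is an assumption, not the author's actual training loop: the optimizer choice (Adam) and the random placeholder batches are illustrative, while the 0.0003 learning rate and the padding id 3 come from the experiment logs and the embedding definition above.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # pos_forward reads this module-level name
net = SamOut(voc_size=235, hidden_size=256, num_heads=16, num_layers=4).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0003)  # lr from the logs; optimizer is assumed
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=3)        # 3 matches the embedding's padding_idx

for step in range(100):                                    # random placeholder batches, not real data
    batch = torch.randint(0, 235, [8, 64]).to(device)      # [batch, seq] token ids
    inputs, targets = batch[:, :-1], batch[:, 1:]          # next-token prediction targets

    logits, _ = net(inputs)                                # [8, 63, 235]
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()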
A few other points:
- The model uses dropout (in FeedForward) to reduce overfitting.
- The model is meant to run on the GPU: it is moved with a .to(device) call, and device is set to "cuda".
- In the body of the model, several MaxState and FeedForward blocks are stacked, presumably to increase the model's expressive capacity.
- The comments in the code record the different experimental settings that were tried, such as whether LayerNorm is used, along with the corresponding loss values.
Overall, the design has clear similarities to the Transformer, but it also includes some distinctive elements, such as the cumulative-max operation and the custom attention-style mechanism.