1. 使用注意力机制的seq2seq
key value等价 是一个东西 第i个词rnn的输出
根据加权的不同,解码器前面用编码器前面的输出,到后面用后面的输出。
2. code
核心代码: context 怎么算
embedding没变,Decoder加了attention层
class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,# num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 输出X的形状为(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形状为(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# context的形状为(batch_size,1,num_hiddens)# query是一直在变的, enc_outputs, enc_outputs解码器每个词的输出,不变。context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)# 全连接层变换后,outputs的形状为# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights
3. QA
7 不会看别的句子的输出,当前句子对应英语句子的输出
8 是decode输入最后一个词的状态。padd或者句子结束的符号
9 可以 bert就是只有encoder,纯attention,没有rnn
10 句子统一长度,加了padding,valid_lens记录原始数据句子的长度
11 不类似 。 束搜索是最后一层全连接层怎么输出,注意力机制是在rnn上的使用。
13 一个子图就是一个key-value, 自注意力机制。