您的位置:首页 > 游戏 > 游戏 > 注意力机制

注意力机制

2024/10/6 8:00:56 来源:https://blog.csdn.net/m0_53881899/article/details/140873876  浏览:    关键词:注意力机制

目录

  • 一、注意力机制的提出:
    • 1.传统神经网络存在的缺点:
    • 2.注意力机制在传统神经网络上的改进:
  • 二、QKV是什么:
  • 三、注意力机制:
    • 1.注意力机制的直观理解:
    • 2.注意力机制的一般定义:
  • 四、注意力分数函数α:
    • 1.加性注意力(Additive attention)
      • 1.1定义:
      • 1.2代码:
    • 2.缩放点积注意力(Scaled Dot-Product Attention)
      • 2.1定义:
      • 2.2代码:
  • 五、如何将Attention用于Seq2seq:
    • 1.模型架构:
    • 2.代码:

一、注意力机制的提出:

1.传统神经网络存在的缺点:

  • 传统的CNN在处理序列信息时通过卷积操作来提取特征,但是不会考虑序列内部的相关性
  • RNN通过递归结构计算隐藏状态ht来提取序列内部的相关性信息,解决了CNN处理序列信息的缺陷,但是ht中记录了所有token的序列信息,并且默认所有token之间的相关性相等,对于当前时间步输入token,ht中的部分序列信息跟当前输入token没关系,属于噪音,直接使用ht中所有序列信息进行计算肯定会损失精度,但是CNN并不具有从ht中选择序列信息的能力;
  • GRU和LSTM通过门的概念来动态选择ht中的序列信息,每个时间步进行计算时通过门的控制会考虑序列信息ht中与当前token的相关性,从而选择相关性高的序列信息参与计算,减少噪音等不相关性信息的影响。但是LSTM和GRU是顺序处理的,每一个时间步都依赖于前一个时间步的输出,因此不能并行处理整个序列,计算效率较低;

2.注意力机制在传统神经网络上的改进:

  • Attnetion应运而生,目的是更好地捕捉和利用序列内元素之间的相关性,而不是简单假设所有token之间都具有相等的相关性,它通过动态计算注意力权重,使得模型能够灵活地选择与当前输入最相关的部分并且可以并行计算序列内部的相关性,不仅能计算当前token与之前序列的相关性,还能计算与后续序列的相关性,解决了GRU和LSTM存在的问题。

二、QKV是什么:

QKV是注意力机制中的重要数据量,记录了序列的序列信息,具体表示什么序列信息需要根据不同应用场景合理的选择。

举个例子,对于英译德的场景来说:
在这里插入图片描述

kv是编码器对每个英文词的rnn输出,包含了该词及其之前的序列信息。
如果有“hello world .”三个词,则每个词对应一个kv,有k1、k2、k3、v1、v2、v3,并且k1=v1、k2=v2、k3=v3。其中k2和v2就表示“hello world”这两个词的序列信息h2。

当前时间步的q是解码器上一个时间步德语词的rnn预测输出,记录了上一个德语词的翻译结果,通过上一个德语词的翻译结果作为q去查与该德语词上下文相关的英语词kv,作为当前待翻译德语词的序列信息(编码器Attention输出)。
如果“hello world .”对应的翻译为“bonjour le monde .”,那么在翻译“le”这个词的时候q就是上一个德语词bonjour,通过使用q去查kv可以找到与德语词bonjour相关性更高的序列信息k2v2,即“hello world”。(为什么k2v2和bonjour的相关性最高?因为bonjour是hello的翻译结果,与bonjour相关性最高的也就是它的上下文,即hello world。如果只是k1v1的话只有hello就缺少下文的相关性,如果是k3v3的话是hello world . 多了个.,这个.对hello没有很高的相关性)(为什么不用当前待翻译德语词去找与他相关的英语序列信息而是用与上一个德语词相关的英语序列信息?因为当前待翻译的词我们预先并不知道,例如使用“bonjour”翻译“le”这个词的时候我们不知道要翻译的是“le”,也就没法没法提前对“le”这个词找相关的序列信息,只能用与他相邻的上个词的序列信息)

Seq2seq网络架构执行英译德的过程如下:
在这里插入图片描述

三、注意力机制:

1.注意力机制的直观理解:

这里假设序列的QKV都是一个数,注意力机制计算当前序列中token’a’与其余token’b’、‘c’、‘d’、‘e’、'f’的相关性过程如下:
在这里插入图片描述
对于选择的token Query,计算其与其余所有token(每个token对应一个key-value键值对)的相关性计算过程:首先将 query 和每一个 key 通过注意力分数函数 a 和 softmax 运算得到相关性权重(与 key 对应的值的概率分布),将这些注意力权重再与对应的 value 进行加权求和,最终就得到了输出。

2.注意力机制的一般定义:

一个query与其余所有key-value计算相关性的过程如下,注意力机制是所有q与所有k计算来获取所有q的注意力结果。
在这里插入图片描述

四、注意力分数函数α:

这里以一个q和一个k的注意力分数计算为例简述打分函数定义。

1.加性注意力(Additive attention)

1.1定义:

当 query 和 key 是不同长度的矢量时,可以使用加性注意力作为注意力分数。。
在这里插入图片描述

  • h:超参数
  • 可学习参数:
    • Wk:key关于h的权重参数(将 key 向量的长度从 k 转化为 h)
    • Wq:query关于h的权重参数(将 query 向量的长度从 q 转化为 h)
  • tanh:激活函数
  • 这里tanh(wkk+wqq)可以通过一个全连接隐藏层实现,其中隐藏层神经元个数为h,没有偏置值。(仔细想想不难理解)

1.2代码:

class AdditiveAttention(nn.Module):"""加性注意力"""def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):super(AdditiveAttention, self).__init__(**kwargs)# 全连接层隐藏层self.W_k = nn.Linear(key_size, num_hiddens, bias=False)self.W_q = nn.Linear(query_size, num_hiddens, bias=False)self.w_v = nn.Linear(num_hiddens, 1, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens):queries, keys = self.W_q(queries), self.W_k(keys)features = queries.unsqueeze(2) + keys.unsqueeze(1)features = torch.tanh(features)scores = self.w_v(features).squeeze(-1)self.attention_weights = masked_softmax(scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)

2.缩放点积注意力(Scaled Dot-Product Attention)

2.1定义:

使用缩放点积可以得到计算效率更高i的评分函数,但是缩放点积操作要求 query 和 key 具有相同的长度
在这里插入图片描述

  • 这里不需要学习任何东西,直接利用 <q,ki> 将 q 和 ki 做内积然后除以根号d(除以根号 d 的目的是为了降低对 ki 的长度的敏感度)

2.2代码:

class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)self.attention_weights = masked_softmax(scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)

五、如何将Attention用于Seq2seq:

1.模型架构:

编码层不变,解码层加入Attention机制。
在这里插入图片描述

2.代码:

import torch
from torch import nn
from d2l import torch as d2l# 带有注意力机制的解码器
class AttentionDecoder(d2l.Decoder):def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError# 解码器架构
class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)# 加性注意力分数函数 self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens,num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):enc_outputs, hidden_state, enc_valid_lens = stateX = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []# 对于每个待翻译的token:xfor x in X:# q为上一个时间步的RNN隐藏层输出hidden_state[-1]query = torch.unsqueeze(hidden_state[-1], dim=1)# enc_outputs维度为batch_size,num_steps,num_hiddens,表示每个时间步的h,一共num_steps个时间步(token)# enc_valid_lens是长为batch_size的一维向量,记录了每个句子的有效长度context = self.attention(query, enc_outputs, enc_outputs,enc_valid_lens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights# 训练    
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens,num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)  # 预测
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ',f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com