nn.Embedding
是 PyTorch 中用于将离散的整数索引(代表类别或符号)转换为连续向量表示的层。这个嵌入层特别适合用于自然语言处理、序列数据、推荐系统、以及生物信息学中的离散符号编码(如氨基酸序列等)等任务。
一、nn.Embedding
的定义和参数
nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
参数说明
num_embeddings
(int):嵌入矩阵的行数,即可表示的类别总数。例如,对于 21 个氨基酸和一个 gap 符号,num_embeddings=21+1=22
。embedding_dim
(int):嵌入向量的维度大小。比如embedding_dim=64
表示每个类别会映射成一个 64 维的向量。padding_idx
(int, 可选):指定填充索引(padding index),填充索引所对应的嵌入向量始终输出零向量。用于处理变长序列。max_norm
(float, 可选):限制嵌入向量的最大范数,超过该值会被缩放到这个范围。