该代码实现了一个自定义的节点编码器 Concat2NodeEncoder
,用于将两个独立的节点编码器的输出结果连接起来。这个类的设计目的是将两个编码器的功能结合起来,以丰富节点特征的表示。通过将 encoder1
和 encoder2
的输出拼接,可以在保留原始特征的同时,加入其他形式的位置编码或其他特征处理。
from lrgb.encoders.composition import Concat2NodeEncoder
import torchclass Concat2NodeEncoder(torch.nn.Module):"""Encoder that concatenates two node encoders."""def __init__(self, enc1_cls, enc2_cls, in_dim, emb_dim, enc2_dim_pe):super().__init__()# PE dims can only be gathered once the cfg is loaded.self.encoder1 = enc1_cls(in_dim=in_dim, emb_dim=emb_dim - enc2_dim_pe)self.encoder2 = enc2_cls(in_dim=in_dim, emb_dim=emb_dim, expand_x=False)def forward(self, x, pestat):x = self.encoder1(x, pestat)x = self.encoder2(x, pestat)return x
1. Concat2NodeEncoder
类定义与初始化
class Concat2NodeEncoder(torch.nn.Module):"""Encoder that concatenates two node encoders."""