该代码实现了一系列类与方法，主要用于图神经网络（GNN）模型的特征编码、激活函数、池化操作等。`EnvArgs` 和 `ActionNetArgs` 类用于根据配置参数生成网络结构，并通过 `ModelType`、`ActivationType`、`Pool` 等控制模型的组件与行为。代码的主要目的是构建可配置的图神经网络模型，并在这些模型中组合不同的特征编码器、激活函数与池化策略。
from enum import Enum, auto
from typing import NamedTuple, Any, Callable

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ModuleList, Module, Dropout, ReLU, GELU, Sequential
from torch_geometric.nn.pool import global_mean_pool, global_add_pool

from helpers.classes import ActivationType, Pool, ModelType
from helpers.metrics import MetricType
from helpers.model import ModelType
from helpers.encoders import DataSetEncoders, PosEncoder
from lrgb.encoders.composition import Concat2NodeEncoder


class ActivationType(Enum):
    """Supported activation functions, available in functional or ``torch.nn`` module form."""

    RELU = auto()
    GELU = auto()

    @staticmethod
    def from_string(s: str):
        """Parse a member name (e.g. ``'RELU'``) into an :class:`ActivationType`.

        Raises:
            ValueError: if ``s`` is not exactly the name of a member.
        """
        try:
            return ActivationType[s]
        except KeyError:
            # Keep ValueError as the public exception type, but say what was wrong.
            valid = ', '.join(ActivationType.__members__)
            raise ValueError(f'Unknown ActivationType {s!r}; expected one of: {valid}') from None

    def get(self):
        """Return the functional form of this activation (``F.relu`` or ``F.gelu``)."""
        if self is ActivationType.RELU:
            return F.relu
        elif self is ActivationType.GELU:
            return F.gelu
        else:
            raise ValueError(f'ActivationType {self.name} not supported')

    def nn(self) -> Module:
        """Return a new ``torch.nn`` module instance (``ReLU()`` or ``GELU()``) for this activation."""
        if self is ActivationType.RELU:
            return ReLU()
        elif self is ActivationType.GELU:
            return GELU()
        else:
            raise ValueError(f'ActivationType {self.name} not supported')


class GumbelArgs(NamedTuple):
    """Settings grouped under the Gumbel configuration; they are consumed outside this file."""

    learn_temp: bool  # presumably: whether the temperature is learned — confirm against consumer
    temp_model_type: ModelType  # model type for the temperature network (assumed from name — confirm)
    tau0: float
    temp: float
    gin_mlp_func: Callable


class Pool(Enum):
    """Pooling strategies; ``get()`` maps each one to a callable ``(x, batch) -> Tensor``."""

    NONE = auto()
    MEAN = auto()
    SUM = auto()

    @staticmethod
    def from_string(s: str):
        """Parse a member name (e.g. ``'MEAN'``) into a :class:`Pool`.

        Raises:
            ValueError: if ``s`` is not exactly the name of a member.
        """
        try:
            return Pool[s]
        except KeyError:
            # Keep ValueError as the public exception type, but say what was wrong.
            valid = ', '.join(Pool.__members__)
            raise ValueError(f'Unknown Pool {s!r}; expected one of: {valid}') from None

    def get(self):
        """Return the pooling callable for this member.

        ``MEAN`` / ``SUM`` map to torch_geometric's ``global_mean_pool`` / ``global_add_pool``.
        ``NONE`` returns a :class:`BatchIdentity` module, which accepts the same ``(x, batch)``
        call and returns ``x`` unchanged.
        """
        if self is Pool.MEAN:
            return global_mean_pool
        elif self is Pool.SUM:
            return global_add_pool
        elif self is Pool.NONE:
            return BatchIdentity()
        else:
            raise ValueError(f'Pool {self.name} not supported')


class EnvArgs(NamedTuple):
    """Hyper-parameters for the environment network, plus ``load_net`` to build it."""

    model_type: ModelType  # supplies the component layers via get_component_list
    num_layers: int  # forwarded to get_component_list
    env_dim: int  # encoder embedding width and the width of every component layer
    layer_norm: bool  # not read by load_net
    skip: bool  # not read by load_net
    batch_norm: bool  # not read by load_net
    dropout: float  # dropout rate inside the decoder MLP
    act_type: ActivationType  # activation inside the decoder MLP
    dec_num_layers: int  # total number of Linear layers in the decoder
    pos_enc: PosEncoder
    dataset_encoders: DataSetEncoders
    metric_type: MetricType  # not read by load_net
    in_dim: int  # input feature width seen by the encoder
    out_dim: int  # output width of the decoder
    gin_mlp_func: Callable  # forwarded to get_component_list

    def load_net(self) -> ModuleList:
        """Build the environment network.

        Returns:
            ``ModuleList([encoder, *components, decoder])``. The encoder maps
            ``in_dim -> env_dim``, the components come from
            ``model_type.get_component_list`` (all widths ``env_dim``), and the
            decoder maps ``env_dim -> out_dim``.
        """
        # Construction order (encoder, components, decoder) is kept deliberately:
        # it determines the order in which parameters are randomly initialised.
        enc_list = self._build_encoders()
        component_list = self.model_type.get_component_list(
            in_dim=self.env_dim, hidden_dim=self.env_dim, out_dim=self.env_dim,
            num_layers=self.num_layers, bias=True, edges_required=True,
            gin_mlp_func=self.gin_mlp_func)
        dec_list = self._build_decoders()
        return ModuleList(enc_list + component_list + dec_list)

    def _build_encoders(self) -> list:
        """Return a one-element list holding the input encoder chosen by ``pos_enc`` / ``dataset_encoders``."""
        if self.pos_enc is PosEncoder.NONE:
            # No positional encoding: use the dataset node encoder alone.
            return [self.dataset_encoders.node_encoder(in_dim=self.in_dim, emb_dim=self.env_dim)]
        if self.dataset_encoders is DataSetEncoders.NONE:
            # Positional encoding only.
            return [self.pos_enc.get(in_dim=self.in_dim, emb_dim=self.env_dim)]
        # Both present: combine them; enc2_dim_pe is presumably the positional-encoding width — confirm.
        return [Concat2NodeEncoder(enc1_cls=self.dataset_encoders.node_encoder,
                                   enc2_cls=self.pos_enc.get,
                                   in_dim=self.in_dim, emb_dim=self.env_dim,
                                   enc2_dim_pe=self.pos_enc.DIM_PE())]

    def _build_decoders(self) -> list:
        """Return a one-element list holding the decoder (``env_dim -> out_dim``).

        For ``dec_num_layers > 1`` the decoder is a ``Sequential`` of
        ``dec_num_layers - 1`` hidden blocks ``Linear(env_dim, env_dim) -> Dropout -> activation``,
        followed by ``Linear(env_dim, out_dim)``. Otherwise it is that single final Linear.
        """
        if self.dec_num_layers <= 1:
            return [Linear(self.env_dim, self.out_dim)]
        mlp_list = []
        for _ in range(self.dec_num_layers - 1):
            # Create new modules on every iteration. List multiplication would repeat
            # the *same* Linear instance and silently tie all hidden-layer weights.
            mlp_list += [Linear(self.env_dim, self.env_dim), Dropout(self.dropout), self.act_type.nn()]
        mlp_list.append(Linear(self.env_dim, self.out_dim))
        return [Sequential(*mlp_list)]


class ActionNetArgs(NamedTuple):
    """Hyper-parameters for the action network, plus ``load_net`` to build it."""

    model_type: ModelType  # supplies the layers via get_component_list
    num_layers: int
    hidden_dim: int
    dropout: float  # not read by load_net
    act_type: ActivationType  # not read by load_net
    env_dim: int  # input width of the action network
    gin_mlp_func: Callable

    def load_net(self) -> ModuleList:
        """Build the action network from ``model_type.get_component_list``.

        The input width is ``env_dim`` and the output width is fixed at 2
        (presumably one score per action — confirm with the caller). Built with
        ``edges_required=False``.
        """
        net = self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.hidden_dim, out_dim=2,
                                                 num_layers=self.num_layers, bias=True, edges_required=False,
                                                 gin_mlp_func=self.gin_mlp_func)
        return ModuleList(net)


class BatchIdentity(Module):
    """No-op pooling with the ``(x, batch)`` call signature of the global pools; returns ``x``."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Any constructor arguments are accepted and ignored.
        super().__init__()

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        """Return ``x`` unchanged; ``batch`` is ignored."""
        return x
这里包括了 `ActivationType`、`GumbelArgs`、`Pool`、`EnvArgs`、`ActionNetArgs` 和 `BatchIdentity` 等类的定义。