Configurable GNN Model Components: Encoders, Activations, and Pooling

2024/11/19 14:47:31  Source: https://blog.csdn.net/sinat_41942180/article/details/143024418

This code implements a set of classes and methods for building graph neural network (GNN) models: feature encoding, activation functions, pooling operations, and related configuration. The EnvArgs and ActionNetArgs classes generate network structures from configuration parameters, while ModelType, ActivationType, and Pool control which components the model uses and how it behaves. The overall goal is a configurable GNN in which different feature encoders, activation functions, and pooling strategies can be swapped in.

```python
from enum import Enum, auto
from typing import NamedTuple, Any, Callable

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ModuleList, Module, Dropout, ReLU, GELU, Sequential
from torch_geometric.nn.pool import global_mean_pool, global_add_pool

from helpers.metrics import MetricType
from helpers.model import ModelType
from helpers.encoders import DataSetEncoders, PosEncoder
from lrgb.encoders.composition import Concat2NodeEncoder


class ActivationType(Enum):
    """An enum for the supported activation types."""
    RELU = auto()
    GELU = auto()

    @staticmethod
    def from_string(s: str):
        try:
            return ActivationType[s]
        except KeyError:
            raise ValueError()

    def get(self):
        # Return the functional form of the activation.
        if self is ActivationType.RELU:
            return F.relu
        elif self is ActivationType.GELU:
            return F.gelu
        else:
            raise ValueError(f'ActivationType {self.name} not supported')

    def nn(self) -> Module:
        # Return the module form of the activation.
        if self is ActivationType.RELU:
            return ReLU()
        elif self is ActivationType.GELU:
            return GELU()
        else:
            raise ValueError(f'ActivationType {self.name} not supported')


class GumbelArgs(NamedTuple):
    """Configuration for the Gumbel-softmax temperature."""
    learn_temp: bool
    temp_model_type: ModelType
    tau0: float
    temp: float
    gin_mlp_func: Callable


class Pool(Enum):
    """An enum for the supported pooling strategies."""
    NONE = auto()
    MEAN = auto()
    SUM = auto()

    @staticmethod
    def from_string(s: str):
        try:
            return Pool[s]
        except KeyError:
            raise ValueError()

    def get(self):
        if self is Pool.MEAN:
            return global_mean_pool
        elif self is Pool.SUM:
            return global_add_pool
        elif self is Pool.NONE:
            return BatchIdentity()  # node-level tasks: no graph pooling
        else:
            raise ValueError(f'Pool {self.name} not supported')


class EnvArgs(NamedTuple):
    """Configuration of the environment network (encoder -> GNN -> decoder)."""
    model_type: ModelType
    num_layers: int
    env_dim: int
    layer_norm: bool
    skip: bool
    batch_norm: bool
    dropout: float
    act_type: ActivationType
    dec_num_layers: int
    pos_enc: PosEncoder
    dataset_encoders: DataSetEncoders
    metric_type: MetricType
    in_dim: int
    out_dim: int
    gin_mlp_func: Callable

    def load_net(self) -> ModuleList:
        # Feature encoder: dataset encoder, positional encoder, or both concatenated.
        if self.pos_enc is PosEncoder.NONE:
            enc_list = [self.dataset_encoders.node_encoder(in_dim=self.in_dim, emb_dim=self.env_dim)]
        else:
            if self.dataset_encoders is DataSetEncoders.NONE:
                enc_list = [self.pos_enc.get(in_dim=self.in_dim, emb_dim=self.env_dim)]
            else:
                enc_list = [Concat2NodeEncoder(enc1_cls=self.dataset_encoders.node_encoder,
                                               enc2_cls=self.pos_enc.get,
                                               in_dim=self.in_dim, emb_dim=self.env_dim,
                                               enc2_dim_pe=self.pos_enc.DIM_PE())]

        # GNN backbone layers.
        component_list = \
            self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.env_dim, out_dim=self.env_dim,
                                               num_layers=self.num_layers, bias=True, edges_required=True,
                                               gin_mlp_func=self.gin_mlp_func)

        # Decoder: an MLP when dec_num_layers > 1, otherwise a single linear layer.
        if self.dec_num_layers > 1:
            mlp_list = (self.dec_num_layers - 1) * [Linear(self.env_dim, self.env_dim),
                                                    Dropout(self.dropout), self.act_type.nn()]
            mlp_list = mlp_list + [Linear(self.env_dim, self.out_dim)]
            dec_list = [Sequential(*mlp_list)]
        else:
            dec_list = [Linear(self.env_dim, self.out_dim)]

        return ModuleList(enc_list + component_list + dec_list)


class ActionNetArgs(NamedTuple):
    """Configuration of the action network (outputs two logits per node)."""
    model_type: ModelType
    num_layers: int
    hidden_dim: int
    dropout: float
    act_type: ActivationType
    env_dim: int
    gin_mlp_func: Callable

    def load_net(self) -> ModuleList:
        net = self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.hidden_dim, out_dim=2,
                                                 num_layers=self.num_layers, bias=True, edges_required=False,
                                                 gin_mlp_func=self.gin_mlp_func)
        return ModuleList(net)


class BatchIdentity(Module):
    """Identity 'pooling': returns node features unchanged (used by Pool.NONE)."""
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        return x
```
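To make the enum helpers concrete, here is a minimal usage sketch. It relies only on the methods defined above and on PyTorch Geometric's `global_mean_pool`, so nothing beyond the shown code is assumed:

```python
import torch

# Functional vs. module form of an activation.
act = ActivationType.from_string('GELU')  # ActivationType.GELU
act_fn = act.get()                        # F.gelu
act_module = act.nn()                     # GELU() module

# Graph-level mean pooling over a single-graph batch.
pool_fn = Pool.from_string('MEAN').get()  # global_mean_pool
x = torch.randn(5, 8)                     # 5 nodes, 8 features each
batch = torch.zeros(5, dtype=torch.long)  # all nodes belong to graph 0
graph_repr = pool_fn(x, batch)            # shape: (1, 8)
```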

In summary, the code above covers the activation wrapper (ActivationType), the pooling strategies (Pool), the Gumbel-softmax configuration (GumbelArgs), the environment network configuration (EnvArgs), the action network configuration (ActionNetArgs), and an identity pooling module for node-level tasks (BatchIdentity).
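As a rough illustration of how these NamedTuples might be wired together, the sketch below instantiates ActionNetArgs and builds its network. Note that `ModelType.GCN` and the `gin_mlp_func` signature are assumptions for illustration only: the real ModelType members and the factory signature expected by `get_component_list` live in helpers/model.py, which this post does not include.

```python
from torch.nn import Linear, ReLU, Sequential

# Hypothetical MLP factory for GIN layers -- the exact signature expected by
# get_component_list is defined in helpers/model.py and is assumed here.
def gin_mlp_func(in_channels: int, out_channels: int, bias: bool):
    return Sequential(Linear(in_channels, out_channels, bias=bias),
                      ReLU(),
                      Linear(out_channels, out_channels, bias=bias))

action_args = ActionNetArgs(
    model_type=ModelType.GCN,       # assumed member name, see helpers/model.py
    num_layers=2,
    hidden_dim=16,
    dropout=0.1,
    act_type=ActivationType.RELU,
    env_dim=64,
    gin_mlp_func=gin_mlp_func,
)
action_net = action_args.load_net()  # ModuleList of GNN layers with out_dim=2
```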
