您的位置:首页 > 健康 > 养生 > 住建个人证书查询网_乐清站在哪_微信营销方法_辽阳网站seo

住建个人证书查询网_乐清站在哪_微信营销方法_辽阳网站seo

2024/12/27 0:54:35 来源:https://blog.csdn.net/sinat_41942180/article/details/143157530  浏览:    关键词:住建个人证书查询网_乐清站在哪_微信营销方法_辽阳网站seo
住建个人证书查询网_乐清站在哪_微信营销方法_辽阳网站seo

该代码提供了几个常用的工具函数,主要用于设置随机种子、处理数据集的划分与合并、以及获取数据的掩码(mask)。这些函数主要用于 PyTorch 和 PyTorch Geometric 框架中的深度学习训练流程。

from helpers.utils import set_seed

from argparse import Namespace  #类型注解,表示传递的参数集合
import torch
import sys   #输出的进度条
import tqdm  #进度条(显示训练和验证的)
from typing import Tuple, Any  #Tuple用于定义元组,Any用于定义任意类型
from torch_geometric.loader import DataLoader
from torch import Tensor
from torch_geometric.typing import OptTensor
import numpy as npfrom helpers.classes import GumbelArgs, EnvArgs, ActionNetArgs, ActivationType
from helpers.metrics import LossesAndMetrics
from helpers.utils import set_seed
from models.CoGNN import CoGNN
from helpers.dataset_classes.dataset import DatasetBySplitclass Experiment(object):def __init__(self, args: Namespace):super().__init__()for arg in vars(args):value_arg = getattr(args, arg)print(f"{arg}: {value_arg}")self.__setattr__(arg, value_arg)self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')set_seed(seed=self.seed)# parametersself.metric_type = self.dataset.get_metric_type()self.decimal = self.dataset.num_after_decimal()self.task_loss = self.metric_type.get_task_loss()# assertsself.dataset.asserts(args)def run(self) -> Tuple[Tensor, Tensor]:dataset = self.dataset.load(seed=self.seed, pos_enc=self.pos_enc)if self.metric_type.is_multilabel():dataset.data.y = dataset.data.y.to(dtype=torch.float)folds = self.dataset.get_folds(fold=self.fold)# locally used parametersout_dim = self.metric_type.get_out_dim(dataset=dataset)gin_mlp_func = self.dataset.gin_mlp_func()env_act_type = self.dataset.env_activation_type()# named tuplesgumbel_args = GumbelArgs(learn_temp=self.learn_temp, temp_model_type=self.temp_model_type, tau0=self.tau0,temp=self.temp, gin_mlp_func=gin_mlp_func)env_args = \EnvArgs(model_type=self.env_model_type, num_layers=self.env_num_layers, env_dim=self.env_dim,layer_norm=self.layer_norm, skip=self.skip, batch_norm=self.batch_norm, dropout=self.dropout,act_type=env_act_type, metric_type=self.metric_type, in_dim=dataset[0].x.shape[1], out_dim=out_dim,gin_mlp_func=gin_mlp_func, dec_num_layers=self.dec_num_layers, pos_enc=self.pos_enc,dataset_encoders=self.dataset.get_dataset_encoders())action_args = \ActionNetArgs(model_type=self.act_model_type, num_layers=self.act_num_layers,hidden_dim=self.act_dim, dropout=self.dropout, act_type=ActivationType.RELU,env_dim=self.env_dim, gin_mlp_func=gin_mlp_func)# foldsmetrics_list = []edge_ratios_list = []for num_fold in folds:set_seed(seed=self.seed)dataset_by_split = self.dataset.select_fold_and_split(num_fold=num_fold, dataset=dataset)best_losses_n_metrics, edge_ratios =\self.single_fold(dataset_by_split=dataset_by_split, gumbel_args=gumbel_args, env_args=env_args,action_args=action_args, num_fold=num_fold)# print finalprint_str = f'Fold {num_fold}/{len(folds)}'for name in best_losses_n_metrics._fields:print_str += f",{name}={round(getattr(be

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com