您的位置:首页 > 财经 > 产业 > 浏览器2345_泉州seo排名_sem账户托管公司_百度认证考试

浏览器2345_泉州seo排名_sem账户托管公司_百度认证考试

2025/4/14 14:00:00 来源:https://blog.csdn.net/weixin_44253237/article/details/147168147  浏览:    关键词:浏览器2345_泉州seo排名_sem账户托管公司_百度认证考试
浏览器2345_泉州seo排名_sem账户托管公司_百度认证考试

文章目录

  • 前言
  • 一、get_random_problems 函数分析
  • 二、augment_xy_data_by_8_fold 函数分析
  • 代码


前言

该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。

对该代码进行改进
20250412-代码改进-拟蒙特卡洛


一、get_random_problems 函数分析

depot_xy = torch.rand(size=(batch_size, 1, 2))
  • 生成仓库坐标:
    • 生成形状为(batch_size, 1, 2) 的随机张量,表示每个批次中仓库的二维坐标(范围 [0,1))。
node_xy = torch.rand(size=(batch_size, problem_size, 2))
  • 生成节点坐标:
    • 生成形状为 (batch_size, problem_size, 2) 的随机张量,表示每个批次中所有节点的二维坐标。
if problem_size == 20:demand_scaler = 30
elif problem_size == 50:demand_scaler = 40
elif problem_size == 100:demand_scaler = 50
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
  • 生成节点需求:
    • 根据 problem_size 选择缩放因子 demand_scaler
    • 生成 1~9 的整数需求,并缩放到 [1/50, 9/50] 等区间,确保需求值为浮点数。

二、augment_xy_data_by_8_fold 函数分析

功能:通过8种几何变换对坐标数据进行增强,扩充数据集。

x = xy_data[:, :, [0]]  # 提取x坐标
y = xy_data[:, :, [1]]  # 提取y坐标
  • 拆分坐标:
    • 从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。
dat1 = torch.cat((x, y), dim=2)          # 原始坐标
dat2 = torch.cat((1 - x, y), dim=2)      # x轴镜像
dat3 = torch.cat((x, 1 - y), dim=2)      # y轴镜像
dat4 = torch.cat((1 - x, 1 - y), dim=2)  # x+y轴镜像
dat5 = torch.cat((y, x), dim=2)          # 转置坐标
dat6 = torch.cat((1 - y, x), dim=2)      # 转置后x轴镜像
dat7 = torch.cat((y, 1 - x), dim=2)      # 转置后y轴镜像
dat8 = torch.cat((1 - y, 1 - x), dim=2)  # 转置后x+y轴镜像
  • 生成8种变换:
    • 对坐标进行镜像翻转和转置操作,生成8种变体。
aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
  • 合并增强数据:
  • 将8种变换后的数据沿批次维度拼接,最终形状为 (8*batch, N, 2)

代码


import torch
import numpy as npdef get_random_problems(batch_size, problem_size):depot_xy = torch.rand(size=(batch_size, 1, 2))# shape: (batch, 1, 2)node_xy = torch.rand(size=(batch_size, problem_size, 2))# shape: (batch, problem, 2)if problem_size == 20:demand_scaler = 30elif problem_size == 50:demand_scaler = 40elif problem_size == 100:demand_scaler = 50else:raise NotImplementedErrornode_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)# shape: (batch, problem)return depot_xy, node_xy, node_demanddef augment_xy_data_by_8_fold(xy_data):# xy_data.shape: (batch, N, 2)x = xy_data[:, :, [0]]y = xy_data[:, :, [1]]# x,y shape: (batch, N, 1)dat1 = torch.cat((x, y), dim=2)dat2 = torch.cat((1 - x, y), dim=2)dat3 = torch.cat((x, 1 - y), dim=2)dat4 = torch.cat((1 - x, 1 - y), dim=2)dat5 = torch.cat((y, x), dim=2)dat6 = torch.cat((1 - y, x), dim=2)dat7 = torch.cat((y, 1 - x), dim=2)dat8 = torch.cat((1 - y, 1 - x), dim=2)aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)# shape: (8*batch, N, 2)return aug_xy_data

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com