The post ray.rllib-入门实践-11: 自定义模型/网络 showed how to define a custom model and register it with Ray for use. Besides registering the model with Ray directly, there is a second way to use a custom model: when writing a custom policy, replace the policy's default model with the custom one. With this approach, whenever the custom policy is used, the custom model is used automatically; no extra registration with Ray is needed, because the custom policy effectively registers the custom model on your behalf. How to define and use a custom policy was covered in an earlier post in this series; a short sketch of the direct-registration route is shown below for contrast.
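For contrast, the direct-registration route from that earlier post looks roughly like the sketch below. This is an illustrative fragment, not part of the runnable example further down; the model name "my_torch_model" is arbitrary, and CustomTorchModel refers to the class defined later in this post.

from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig

## register the custom model under a chosen name ...
ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
## ... then reference it in the model config via the "custom_model" key
config = PPOConfig().training(model={"custom_model": "my_torch_model"})

The rest of this post takes the policy-based route instead, so no explicit registration call appears in the example code.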
Environment configuration:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
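One way to set up these pinned versions (a suggestion, not taken from the original post) is with pip inside a Python 3.9 environment:

pip install torch==2.5.1 numpy==1.23.0 "ray[rllib,tune,serve]==2.10.0"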
The example code is as follows:
import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym
from gymnasium import spaces
from ray.rllib.utils.typing import TensorType, ModelConfigDict
import ray
from ray.rllib.models import ModelCatalog  # ModelCatalog: used to register models and to look up an env's preprocessors and action distributions
from ray.tune.logger import pretty_print
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
import torch
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch


## 1. Custom model
class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,  ## filled from PPOConfig.training(model=...), i.e. the parameters in config.model
                 name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        ## define the network layers
        obs_dim = int(np.prod(obs_space.shape))
        action_dim = int(np.prod(action_space.shape))
        self.activation = nn.ReLU()
        ## shareNet
        self.shared_fc = nn.Linear(obs_dim, 128)
        ## actorNet
        # self.actorNet = nn.Linear(128, action_dim)
        self.actorNet = nn.Linear(128, num_outputs)  # the last layer must output num_outputs; using action_dim can raise errors in some cases
        ## criticNet
        self.criticNet = nn.Linear(128, 1)
        self._feature = None

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._feature = self.shared_fc.forward(obs)
        action_logits = self.actorNet.forward(self._feature)
        action_logits = self.activation(action_logits)
        ## check that the custom model is actually being used
        print("xxxxxxxxxxxxxxxxx using the custom model: CustomTorchModel")
        return action_logits, state

    def value_function(self):
        value = self.criticNet.forward(self._feature).squeeze(1)
        return value


## 2. Custom policy: override the default model and the loss function
class MY_PPOTorchPolicy(PPOTorchPolicy):
    """PyTorch policy class used with PPO."""

    def __init__(self,
                 observation_space: gym.spaces.Box,
                 action_space: gym.spaces.Box,
                 config: PPOConfig):
        PPOTorchPolicy.__init__(self, observation_space, action_space, config)
        ## PPOTorchPolicy internally calls to_dict() on the PPOConfig, so config can be used as a dict from here on

    # Use the custom model by overriding the policy's default model.
    # When RLlib runs this custom policy, the custom model is effectively registered with Ray through the policy.
    def make_model_and_action_dist(self):
        dist_class, logit_dim = ModelCatalog.get_action_dist(self.action_space,
                                                             self.config['model'],
                                                             framework=self.framework)
        model = CustomTorchModel(obs_space=self.observation_space,
                                 action_space=self.action_space,
                                 num_outputs=logit_dim,
                                 model_config=self.config['model'],
                                 name='My_CustomTorchModel')
        return model, dist_class

    @override(PPOTorchPolicy)
    def loss(self,
             model: ModelV2,
             dist_class: Type[ActionDistribution],
             train_batch: SampleBatch):
        ## original loss
        original_loss = super().loss(model, dist_class, train_batch)  # PPO's built-in loss; a completely new loss could be written instead, but that is strongly discouraged
        ## additional custom loss, here an L2 regularization term as an example
        additional_loss = torch.tensor(0.)
        for param in model.parameters():
            additional_loss += torch.norm(param)
        ## combine into the updated loss
        new_loss = original_loss + 0.01 * additional_loss
        ## check that the custom policy is actually being used
        print("xxxxxxxxxxxxxxxxx using the custom policy: MY_PPOTorchPolicy")
        return new_loss


## 3. Wrap the custom policy into an algorithm; training and configuration always go through the algorithm.
class MY_PPO(PPO):
    ## override PPO.get_default_policy_class so that it returns the custom policy
    def get_default_policy_class(self, config):
        return MY_PPOTorchPolicy


if __name__ == "__main__":
    ## run a quick test of the custom model and policy
    ray.init()
    model_config_dict = {}
    config = PPOConfig(algo_class=MY_PPO)               ## use the custom algorithm
    config = config.environment("CartPole-v1")
    config = config.rollouts(num_rollout_workers=2)
    config = config.framework(framework="torch")
    config = config.training(model=model_config_dict)   ## pass the new model config; all other defaults are kept
    algo = config.build()
    for i in range(3):
        result = algo.train()
        print(f"iter_{i}")
    print("== training finished ==")
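As a quick sanity check after training (this snippet is an addition to the original example and assumes the script above), the following lines could be appended inside the __main__ block to confirm that the custom classes are actually in use and to shut everything down cleanly:

    ## with the setup above, these should print MY_PPOTorchPolicy and CustomTorchModel
    policy = algo.get_policy()
    print(type(policy))
    print(type(policy.model))
    algo.stop()
    ray.shutdown()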