第 18 章 离线强化学习(gym版本 >= 0.26)
- 摘要
- SAC 算法部分
- CQL 算法
- CQL 总结与大函数意义
- CQL 总结
- CQL 类详细分析
摘要
本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!
对应动手学强化学习——离线强化学习
SAC 算法部分
代码讲解对应基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)
为了生成数据集,在倒立摆环境中从零开始训练一个在线 SAC 智能体,直到算法达到收敛效果,把训练过程中智能体采集的所有轨迹保存下来作为数据集。这样,数据集中既包含训练初期较差策略的采样,又包含训练后期较好策略的采样,是一个混合数据集。下面给出生成数据集的代码,SAC 部分直接使用 14.5 节中的代码,因此不再详细解释。——18.4 CQL 代码实践
import numpy as np
import gym
from tqdm import tqdm
import random
import rl_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt# 传统的SAC算法
class PolicyNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)self.fc_std = torch.nn.Linear(hidden_dim, action_dim)self.action_bound = action_bounddef forward(self, x):x = F.relu(self.fc1(x))mu = self.fc_mu(x)std = F.softplus(self.fc_std(x))dist = Normal(mu, std)normal_sample = dist.rsample() # rsample()是重参数化采样log_prob = dist.log_prob(normal_sample)action = torch.tanh(normal_sample)# 计算tanh_normal分布的对数概率密度log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)action = action * self.action_boundreturn action, log_probclass QValueNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)class SACContinuous:''' 处理连续动作的SAC算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,device):self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,action_bound).to(device) # 策略网络self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,action_dim).to(device) # 第一个Q网络self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,action_dim).to(device) # 第二个Q网络self.target_critic_1 = QValueNetContinuous(state_dim,hidden_dim, action_dim).to(device) # 第一个目标Q网络self.target_critic_2 = QValueNetContinuous(state_dim,hidden_dim, action_dim).to(device) # 第二个目标Q网络# 令目标Q网络的初始参数和Q网络一样self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),lr=critic_lr)# 使用alpha的log值,可以使训练结果比较稳定self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)self.log_alpha.requires_grad = True #对alpha求梯度self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],lr=alpha_lr)self.target_entropy = target_entropy # 目标熵的大小self.gamma = gammaself.tau = tauself.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.actor(state)[0]return [action.item()]def calc_target(self, rewards, next_states, dones): # 计算目标Q值next_actions, log_prob = self.actor(next_states)entropy = -log_probq1_value = self.target_critic_1(next_states, next_actions)q2_value = self.target_critic_2(next_states, next_actions)next_value = torch.min(q1_value,q2_value) + self.log_alpha.exp() * entropytd_target = rewards + self.gamma * next_value * (1 - dones)return td_targetdef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(),net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) +param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions'],dtype=torch.float).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)rewards = (rewards + 8.0) / 8.0 # 对倒立摆环境的奖励进行重塑# 更新两个Q网络td_target = self.calc_target(rewards, next_states, dones)critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))self.critic_1_optimizer.zero_grad()critic_1_loss.backward()self.critic_1_optimizer.step()self.critic_2_optimizer.zero_grad()critic_2_loss.backward()self.critic_2_optimizer.step()# 更新策略网络new_actions, log_prob = self.actor(states)entropy = -log_probq1_value = self.critic_1(states, new_actions)q2_value = self.critic_2(states, new_actions)actor_loss = torch.mean(-self.log_alpha.exp() * entropy -torch.min(q1_value, q2_value))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0] # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005 # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau,gamma, device)return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,replay_buffer, minimal_size,batch_size)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()
CQL 算法
CQL 总结与大函数意义
CQL(Conservative Q-Learning) 类的整体设计在 SAC 算法基础上增加了保守性正则项,用于减少 Q 函数过估计,改善离线 RL 表现。
-
构造函数
- 意义:初始化 SAC 中的 actor、critic(两个及对应目标网络)、以及温度参数和优化器,并传入 CQL 特有的超参数 beta(正则系数)和 num_random(动作采样数)。
- 输入:状态、动作维度、各个学习率、目标熵、tau、gamma、设备、beta、num_random。
- 输出:构造好的 CQL 实例,准备进行更新。
-
take_action
- 意义:给定状态,利用 actor 网络采样动作,用于与环境交互。
- 输入:单个状态(例如 [0.1, 0.2, -0.1])。
- 输出:对应动作(例如 [0.8])。
-
soft_update
- 意义:平滑更新目标 Q 网络参数,保证训练稳定性。
- 输入:当前 Q 网络和对应目标网络。
- 输出:目标网络参数更新为旧值与当前值的线性组合。
-
update
- 意义:使用一批真实环境转移数据更新 actor、critic 网络和温度参数 (\alpha),同时在 SAC 基础上加入 CQL 正则项
L CQL = L critic + β ( log ∑ exp ( Q ( s , a ′ ) − log π ref ( a ′ ) ) − E ( s , a ) ∼ D [ Q ( s , a ) ] ) L_{\text{CQL}} = L_{\text{critic}} + \beta \Big(\log\sum \exp(Q(s,a') - \log \pi_{\text{ref}}(a')) - \mathbb{E}_{(s,a) \sim D}[Q(s,a)]\Big) LCQL=Lcritic+β(log∑exp(Q(s,a′)−logπref(a′))−E(s,a)∼D[Q(s,a)])
其中通过对随机动作、策略动作和下一动作进行采样,计算 logsumexp 的项。 - 步骤:
- 计算 TD 目标 y = r + γ ( min ( Q 1 ′ , Q 2 ′ ) + α ⋅ entropy ) y = r + \gamma ( \min(Q_1', Q_2') + \alpha \cdot \text{entropy}) y=r+γ(min(Q1′,Q2′)+α⋅entropy)。
- 分别计算 critic_1_loss 与 critic_2_loss(均方误差损失)。
- 额外采样一批随机动作(均匀分布),以及策略生成的当前和下一动作,对各个 Q 网络的输出进行 logsumexp 操作,形成 CQL 正则项;
- 总的 critic 损失 = SAC critic loss + β \beta β×(CQL 正则项差值)。
- 分别更新 critic_1 与 critic_2;
- 更新 actor,使其最大化 min ( Q 1 , Q 2 ) − α log π ( a ∣ s ) \min(Q_1, Q_2) - \alpha \log\pi(a|s) min(Q1,Q2)−αlogπ(a∣s);
- 更新 α \alpha α 使策略熵接近目标熵;
- 最后对目标网络进行软更新。
- 输入:
- transition_dict 包含 ‘states’, ‘actions’, ‘rewards’, ‘next_states’, ‘dones’。
- 输出:
- 模型参数更新后,策略和 Q 网络得到改进。
- 意义:使用一批真实环境转移数据更新 actor、critic 网络和温度参数 (\alpha),同时在 SAC 基础上加入 CQL 正则项
CQL 总结
CQL 类在 SAC 框架下增加了保守性正则项,以降低 Q 函数对未见动作过高的估计,从而实现更为保守的 Q 学习。主要流程包括:
- 网络初始化:构造策略网络、两个 Q 网络、对应目标网络,温度参数 (\alpha) 及其优化器,以及 SAC 固有参数(gamma、tau、目标熵)和 CQL 超参数(beta、num_random)。
- 动作采样(take_action):给定状态通过 actor 网络采样动作。
- 软更新(soft_update):平滑更新目标网络参数。
- 更新过程(update):
- 从 transition_dict 中读取数据并转换为张量;
- 计算 SAC 的 TD 目标和 critic 损失;
- 额外采样随机动作及策略动作,并计算 CQL 正则项(通过 logsumexp);
- 总的 critic 损失 = SAC critic loss + beta*(正则项差值);
- 更新 critic 网络、策略网络及温度参数 (\alpha);
- 最后软更新目标网络参数。
CQL 类详细分析
class CQL:''' CQL算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,device, beta, num_random):"""定义 CQL 类构造函数,接收状态、隐藏、动作维度、动作范围,以及各优化器学习率、目标熵、tau、gamma、设备、CQL正则系数 beta 和随机采样数 num_random。"""'''创建策略网络(actor),用于输出动作分布并采样连续动作。'''self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)'''创建两个 Q 网络,分别用于评估 (state,action) 对的价值,帮助降低过估计偏差。'''self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)'''创建目标 Q 网络,用于计算 TD 目标,使得训练更稳定。'''self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)'''复制 critic 网络参数到目标网络,保证初始时一致。'''self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())'''为策略网络分配 Adam 优化器,学习率 actor_lr。'''self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)'''分别为两个 Q 网络创建 Adam 优化器。'''self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)'''初始化温度参数的对数,log_alpha = log(0.01) ≈ -4.6052。'''self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)'''使 log_alpha 可求梯度,从而在训练过程中自动调整 alpha=exp({log_alpha})。'''self.log_alpha.requires_grad = True #对alpha求梯度'''为 log_alpha 创建 Adam 优化器,学习率 alpha_lr。'''self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)'''保存目标熵,用于温度调整。'''self.target_entropy = target_entropy # 目标熵的大小'''保存折扣因子 gamma。'''self.gamma = gamma'''保存软更新系数 tau,用于目标网络更新。'''self.tau = tau'''保存 CQL 损失函数中的系数 beta,用于权衡 Q 网络额外正则项。'''self.beta = beta # CQL损失函数中的系数'''保存 CQL 中在计算正则项时所采样的随机动作数,用于估计动作分布上界。'''self.num_random = num_random # CQL中的动作采样数def take_action(self, state):"""定义策略执行接口,给定单个状态,输出对应动作。"""if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(device)# action = self.actor(state)[0]# log_prob = self.actor(state)[1]action, log_prob = self.actor(state)return [action.item()]def soft_update(self, net, target_net):"""对目标网络进行软更新,即用当前网络参数更新目标网络参数。"""for param_target, param in zip(target_net.parameters(), net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)def update(self, transition_dict):"""定义策略与 Q 网络的更新过程,使用从环境采集的经验数据更新所有网络参数,同时计算 CQL 的额外正则项。"""'''从transition_dict中提取数据'''states = torch.tensor(transition_dict['states'], dtype=torch.float).to(device)actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(device)'''对奖励进行归一化处理,这里将倒立摆环境的奖励平移与缩放,使得奖励范围更加稳定。'''rewards = (rewards + 8.0) / 8.0 # 对倒立摆环境的奖励进行重塑'''对所有下一状态通过 actor 得到下一动作及其对数概率。'''next_actions, log_prob = self.actor(next_states)'''计算熵项,即熵 = - log_prob。'''entropy = -log_prob'''用目标 Q 网络计算下一状态下的 Q 值估计。'''q1_value = self.target_critic_1(next_states, next_actions)q2_value = self.target_critic_2(next_states, next_actions)'''计算下一时刻的价值估计,选择较小的 Q 值(双 Q 机制),并加入熵正则项 α⋅entropy。'''next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy'''计算 TD 目标,若 done 为 1(终止状态)则不加折扣。td_target = 当前的即时奖励 + gamma*下一阶段的奖励'''td_target = rewards + self.gamma * next_value * (1 - dones)'''分别计算两个 Q 网络的均方误差(MSE)损失,相对于 td_target(detach防止梯度流向目标)。'''critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))# 以上与SAC相同,以下Q网络更新是CQL的额外部分'''取当前批次样本数量。'''batch_size = states.shape[0]'''作用:- 生成一组随机均匀分布的动作,形状为 (batch_size*num_random, action_dim),范围在 [-1,1]。- 用于 CQL 正则项计算,作为额外的动作样本。数值例子:- 若 batch_size=64, num_random=10, action_dim=1,则生成形状为 (640,1) 的随机动作,如 [[0.23], [-0.45], …]。'''random_unif_actions = torch.rand([batch_size * self.num_random, actions.shape[-1]], dtype=torch.float).uniform_(-1, 1).to(device)'''作用:计算均匀分布的对数概率密度。- 由于在连续区间 [−1,1] 中,均匀密度为 1/2 对于每个动作维度,因此对数概率为 log(0.5);- 对于 action_dim 个维度,总和为 log(0.5^action_dim)。'''random_unif_log_pi = np.log(0.5**next_actions.shape[-1])# random_unif_log_pi = np.log(0.5) * next_actions.shape[-1]'''作用:增加数据集- 将 states 增加一个新维度,重复 num_random 次,使得每个样本重复 num_random 次,最后 reshape 成 (batch_size*num_random, state_dim)。数值例子:- 若 states shape=(64,3), unsqueeze后变为 (64,1,3), repeat 后变为 (64,10,3), view 后为 (640,3).'''tmp_states = states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, states.shape[-1])tmp_next_states = next_states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, next_states.shape[-1])'''用策略网络对 tmp_states(重复后的真实状态)采样动作,得到随机采样的当前动作及其对数概率。'''random_curr_actions, random_curr_log_pi = self.actor(tmp_states)'''同理,对 tmp_next_states 采样下一步动作和对应对数概率。'''random_next_actions, random_next_log_pi = self.actor(tmp_next_states)'''用 critic 网络计算对于 tmp_states 与随机均匀动作 random_unif_actions 的 Q 值,之后 reshape 成 (batch_size, num_random, 1)。'''q1_unif = self.critic_1(tmp_states, random_unif_actions).view(-1, self.num_random, 1)q2_unif = self.critic_2(tmp_states, random_unif_actions).view(-1, self.num_random, 1)'''用 critic 网络计算对于 tmp_states 与随机当前动作的 Q 值,并 reshape。'''q1_curr = self.critic_1(tmp_states, random_curr_actions).view(-1, self.num_random, 1)q2_curr = self.critic_2(tmp_states, random_curr_actions).view(-1, self.num_random, 1)'''用 critic 网络计算对于 tmp_states 与随机下一动作的 Q 值,并 reshape。'''q1_next = self.critic_1(tmp_states, random_next_actions).view(-1, self.num_random, 1)q2_next = self.critic_2(tmp_states, random_next_actions).view(-1, self.num_random, 1)'''作用:- 将三部分 Q 值进行拼接:1. 对于随机均匀采样动作:减去其对数概率(固定值);2. 对于策略采样的当前动作:减去对应对数概率(detach 后避免梯度传递);3. 对于策略采样的下一动作:同理。- 拼接维度为第 1 维(即动作样本维度)。数值例子:- 假设每个部分形状为 (64,10,1),拼接后 q1_cat 形状为 (64,30,1)。在 Conservative Q-Learning (CQL) 中,为了防止 Q 函数在离线数据之外的动作上过高估计,会在 critic 损失中增加一个正则项。正则项的思想是“保守地”估计 Q 值,使得对于未见过的动作,Q 值不至于过高。具体来说,CQL 的正则项大致形式为:Penalty=logE_{a∼μ}[exp(Q(s,a)−logμ(a))]−E_{(s,a)∼D}[Q(s,a)]'''q1_cat = torch.cat([q1_unif - random_unif_log_pi,q1_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),q1_next - random_next_log_pi.detach().view(-1, self.num_random, 1)], dim=1)q2_cat = torch.cat([q2_unif - random_unif_log_pi,q2_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),q2_next - random_next_log_pi.detach().view(-1, self.num_random, 1)], dim=1)'''作用:- 对 q1_cat 和 q2_cat 在动作维度上做 logsumexp 操作,再取平均,得到一个标量,表示对所有随机动作样本的软最大值。- logsumexp 是一种平滑的最大值函数,计算公式:log∑_iexp(xi)'''qf1_loss_1 = torch.logsumexp(q1_cat, dim=1).mean()qf2_loss_1 = torch.logsumexp(q2_cat, dim=1).mean()'''分别计算 critic 网络对于当前真实 (states, actions) 的 Q 值的均值。'''qf1_loss_2 = self.critic_1(states, actions).mean()qf2_loss_2 = self.critic_2(states, actions).mean()'''将原本 SAC 中的 critic 损失(均方误差)与 CQL 正则项相加,构成最终 critic 损失。'''qf1_loss = critic_1_loss + self.beta * (qf1_loss_1 - qf1_loss_2)qf2_loss = critic_2_loss + self.beta * (qf2_loss_1 - qf2_loss_2)'''对 critic_1 的损失进行反向传播更新。对 critic_2 的损失进行反向传播更新。'''self.critic_1_optimizer.zero_grad()qf1_loss.backward(retain_graph=True)self.critic_1_optimizer.step()self.critic_2_optimizer.zero_grad()qf2_loss.backward(retain_graph=True)self.critic_2_optimizer.step()# 更新策略网络'''使用当前策略对真实状态采样动作,并获得对应对数概率。self.actor未更新'''new_actions, log_prob = self.actor(states)entropy = -log_prob'''评估当前策略生成的动作的 Q 值,使用 critic_1 和 critic_2 分别计算,并取最小值以降低过估计。self.critic_1和self.critic_2是刚更新过的'''q1_value = self.critic_1(states, new_actions)q2_value = self.critic_2(states, new_actions)'''作用:- 计算策略损失,目标是最大化 min(𝑄1,𝑄2)−𝛼log𝜋;这里取负号后作为损失。'''actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值'''作用:- 计算温度参数 α 的损失,目标是使策略熵接近目标熵。- detach() 表示不对 entropy 反向传播梯度,只更新 alpha。SAC原始论文中:J(α)=E_{a∼π}[−α(logπ(a∣s)+Htarget)],其中H=−logπ(a∣s)'''alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)random.seed(0)
np.random.seed(0)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)beta = 5.0
num_random = 5
num_epochs = 100
num_trains_per_epoch = 500agent = CQL(state_dim, hidden_dim, action_dim, action_bound, actor_lr,critic_lr, alpha_lr, target_entropy, tau, gamma, device, beta,num_random)return_list = []
for i in range(10):with tqdm(total=int(num_epochs / 10), desc='Iteration %d' % i) as pbar:for i_epoch in range(int(num_epochs / 10)):# 此处与环境交互只是为了评估策略,最后作图用,不会用于训练epoch_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateepoch_return += rewardreturn_list.append(epoch_return)for _ in range(num_trains_per_epoch):b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)if (i_epoch + 1) % 10 == 0:pbar.set_postfix({'epoch':'%d' % (num_epochs / 10 * i + i_epoch + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)epochs_list = list(range(len(return_list)))
plt.plot(epochs_list, return_list)
plt.xlabel('Epochs')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()