In the previous article, we explored distributed reinforcement learning and the IMPALA algorithm, showing how parallelized training can improve the efficiency of reinforcement learning. This article focuses on Offline Reinforcement Learning (Offline RL), an emerging direction, and implements the Conservative Q-Learning (CQL) algorithm, using a static dataset provided by Minari to train a safe reinforcement learning policy.
I. Offline Reinforcement Learning and CQL Principles
1. Characteristics of Offline Reinforcement Learning
- No environment interaction: learns directly from a pre-collected static dataset
- High data efficiency: reuses historical experience (e.g., human demonstrations, log data)
- Low safety risk: avoids dangerous behavior during online exploration
2. Core Idea of CQL
CQL prevents over-estimation of the value function through conservative policy evaluation. Its objective function is:

$$\min_{Q}\ \alpha\Big(\mathbb{E}_{s\sim\mathcal{D},\,a\sim\mu(\cdot\mid s)}\big[Q(s,a)\big]-\mathbb{E}_{(s,a)\sim\mathcal{D}}\big[Q(s,a)\big]\Big)+\frac{1}{2}\,\mathbb{E}_{(s,a,s')\sim\mathcal{D}}\Big[\big(Q(s,a)-\hat{\mathcal{B}}^{\pi}\hat{Q}(s,a)\big)^{2}\Big]$$

where $\mathcal{D}$ is the offline dataset, $\mu$ is the action distribution used for the conservative term, $\hat{\mathcal{B}}^{\pi}$ is the empirical Bellman operator, and $\alpha$ controls the degree of conservatism.
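To make the conservative term concrete, here is a minimal sketch of the penalty in PyTorch. It assumes a q_net callable that returns a single Q-value per state-action pair (a simplification of the twin-Q network used later) and approximates $\mu$ with uniform random actions in [-1, 1]:

import torch

def conservative_penalty(q_net, states, actions, num_random=10):
    # E_{a~mu}[Q(s,a)] approximated with uniform random actions in [-1, 1]
    batch_size, action_dim = actions.shape
    rand_actions = torch.empty(batch_size * num_random, action_dim,
                               device=actions.device).uniform_(-1.0, 1.0)
    rand_states = states.repeat_interleave(num_random, dim=0)
    q_rand = q_net(rand_states, rand_actions).view(batch_size, num_random).mean(dim=1)
    # Q-values of the actions actually stored in the dataset
    q_data = q_net(states, actions).squeeze(-1)
    # Push Q down on out-of-distribution actions and up on dataset actions
    return (q_rand - q_data).mean()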
3. Advantages of the Algorithm
- Prevents policy degradation caused by distribution shift
- Supports mixed-quality datasets (expert data + random data)
- Suitable for real-world scenarios (e.g., healthcare, finance)
II. CQL Implementation Steps (Based on a Minari Dataset)
We will train a policy using the D4RL/door/human-v2 dataset from the Minari library:
- Install Minari and load the dataset (a minimal loading snippet follows this list)
- Define the conservative Q-network
- Implement the conservative regularization loss
- Optimize and evaluate the policy
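As a quick sanity check for the first step, loading and inspecting the dataset might look like the snippet below; the metadata attributes follow the Minari API, but verify them against your installed version:

# pip install "minari[all]"
import minari

dataset = minari.load_dataset("D4RL/door/human-v2", download=True)
print("episodes:", dataset.total_episodes)
print("observation space:", dataset.observation_space)
print("action space:", dataset.action_space)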
III. Code Implementation
Below is the complete implementation of the CQL algorithm:
import torch
import minari
import numpy as np
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from collections import deque
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# 1. Enhanced configuration class (hyperparameters)
class SafeConfig:
    # Training parameters
    batch_size = 1024
    lr = 3e-5
    tau = 0.007
    gamma = 0.99
    total_epochs = 500
    # Network architecture
    hidden_dim = 768
    num_layers = 3
    dropout_rate = 0.1
    activation_fn = 'Mish'  # supports Mish / SiLU / ReLU
    # Regularization parameters
    conservative_init = 2.5
    conservative_decay = 0.995
    min_conservative = 0.3
    reward_scale = 4.0
    # Exploration parameters
    noise_scale = 0.2
    noise_clip = 0.5
    candidate_samples = 400
    imitation_ratio = 0.15
# 2. Safe data loading system
class SafeDataset(Dataset):
    def __init__(self, dataset_name):
        # Load the raw Minari dataset
        dataset = minari.load_dataset(dataset_name, download=True)
        # Infer dimensions from the first episode
        first_ep = dataset[0]
        self.state_dim = first_ep.observations[0].shape[0]
        self.action_dim = first_ep.actions[0].shape[0]
        # Transition storage
        self.obs, self.acts, self.rews, self.dones, self.next_obs = [], [], [], [], []
        for ep in dataset:
            self._store_episode(
                ep.observations[:-1],
                ep.actions,
                ep.rewards,
                np.logical_or(ep.terminations, ep.truncations),
                ep.observations[1:])
        # Normalization
        self._normalize()
        self.priorities = np.ones(len(self.obs)) * 1e-5

    def _store_episode(self, obs, acts, rews, dones, next_obs):
        self.obs.extend(obs)
        self.acts.extend(acts)
        self.rews.extend(rews)
        self.dones.extend(dones)
        self.next_obs.extend(next_obs)

    def _normalize(self):
        # State normalization
        self.obs_mean = np.mean(self.obs, axis=0)
        self.obs_std = np.std(self.obs, axis=0) + 1e-8
        self.obs = (self.obs - self.obs_mean) / self.obs_std
        self.next_obs = (self.next_obs - self.obs_mean) / self.obs_std
        # Action normalization
        self.act_mean = np.mean(self.acts, axis=0)
        self.act_std = np.std(self.acts, axis=0) + 1e-8
        self.acts = (self.acts - self.act_mean) / self.act_std

    def update_priorities(self, indices, priorities):
        self.priorities[indices] = np.abs(priorities.flatten()) + 1e-5

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        return (
            idx,
            torch.FloatTensor(self.obs[idx]),
            torch.FloatTensor(self.acts[idx]),
            torch.FloatTensor(self.next_obs[idx]),
            torch.FloatTensor([self.rews[idx]]),
            torch.FloatTensor([bool(self.dones[idx])]),
        )
# 3. Dimension-safe network architecture
class SafeQNetwork(torch.nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.input_dim = state_dim + action_dim  # computed dynamically from the dataset
        # Online network
        self.feature_net = self._build_network()
        self.q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        self.q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        # Target network
        self.target_net = self._build_network()
        self.target_q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        self.target_q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        # Initialization
        self._init_weights()
        self._update_target(1.0)

    def _build_network(self):
        layers = []
        input_dim = self.input_dim  # use the dynamically computed value
        for _ in range(SafeConfig.num_layers):
            layers.extend([
                torch.nn.Linear(input_dim, SafeConfig.hidden_dim),
                torch.nn.LayerNorm(SafeConfig.hidden_dim),
                self._activation(),
                torch.nn.Dropout(SafeConfig.dropout_rate),
            ])
            input_dim = SafeConfig.hidden_dim
        return torch.nn.Sequential(*layers)

    def _activation(self):
        return {
            'Mish': torch.nn.Mish(),
            'SiLU': torch.nn.SiLU(),
            'ReLU': torch.nn.ReLU(),
        }[SafeConfig.activation_fn]

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
                torch.nn.init.normal_(m.bias, 0, 0.1)

    def forward(self, state, action):
        # Dimension checks
        assert state.shape[-1] == self.state_dim, f"State dim error: {state.shape[-1]} vs {self.state_dim}"
        assert action.shape[-1] == self.action_dim, f"Action dim error: {action.shape[-1]} vs {self.action_dim}"
        x = torch.cat([state, action], dim=1)
        features = self.feature_net(x)
        return self.q1(features), self.q2(features)

    def target_forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        features = self.target_net(x)
        return self.target_q1(features), self.target_q2(features)

    def _update_target(self, tau):
        # Polyak averaging of the target networks
        with torch.no_grad():
            for t_param, param in zip(self.target_net.parameters(), self.feature_net.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
            for t_param, param in zip(self.target_q1.parameters(), self.q1.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
            for t_param, param in zip(self.target_q2.parameters(), self.q2.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
# 4. Safe training system
class SafeTrainer:
    def __init__(self, dataset_name):
        # Data system
        self.dataset = SafeDataset(dataset_name)
        self.state_dim = self.dataset.state_dim
        self.action_dim = self.dataset.action_dim
        # Network system
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_net = SafeQNetwork(self.state_dim, self.action_dim).to(self.device)
        # Optimization system
        self.optimizer = torch.optim.AdamW(
            self.q_net.parameters(),
            lr=SafeConfig.lr,
            weight_decay=1e-3)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=100,
            eta_min=1e-6)
        # Data loading with prioritized sampling
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=SafeConfig.batch_size,
            sampler=WeightedRandomSampler(
                self.dataset.priorities,
                num_samples=len(self.dataset),
                replacement=True),
            collate_fn=lambda b: {
                'indices': torch.LongTensor([x[0] for x in b]),
                'states': torch.stack([x[1] for x in b]),
                'actions': torch.stack([x[2] for x in b]),
                'next_states': torch.stack([x[3] for x in b]),
                'rewards': torch.stack([x[4] for x in b]),
                'dones': torch.stack([x[5] for x in b])},
            num_workers=4)
        # Training state
        self.conservative_weight = SafeConfig.conservative_init
        self.loss_history = deque(maxlen=100)

    def train_epoch(self, epoch):
        self.q_net.train()
        total_loss = 0.0
        for batch in self.dataloader:
            # Prepare the batch
            states = batch['states'].to(self.device)
            actions = batch['actions'].to(self.device)
            next_states = batch['next_states'].to(self.device)
            rewards = batch['rewards'].to(self.device) * SafeConfig.reward_scale
            dones = batch['dones'].to(self.device)
            # Target Q-value computation
            with torch.no_grad():
                # Noisy action generation (target smoothing)
                noise = torch.randn_like(actions) * SafeConfig.noise_scale
                noise = torch.clamp(noise, -SafeConfig.noise_clip, SafeConfig.noise_clip)
                noisy_actions = actions + noise
                # Double Q-learning
                target_q1, target_q2 = self.q_net.target_forward(next_states, noisy_actions)
                target_q = torch.min(target_q1, target_q2).squeeze(-1)
                y = rewards.squeeze(-1) + (1 - dones.squeeze(-1)) * SafeConfig.gamma * target_q
            # Current Q-value prediction
            current_q1, current_q2 = self.q_net(states, actions)
            current_q1 = current_q1.squeeze(-1).clamp(-10.0, 50.0)
            current_q2 = current_q2.squeeze(-1).clamp(-10.0, 50.0)
            # Bellman loss
            bellman_loss = 0.5 * (
                torch.nn.functional.huber_loss(current_q1, y, delta=1.0) +
                torch.nn.functional.huber_loss(current_q2, y, delta=1.0))
            # Conservative regularization term
            rand_acts = torch.randn_like(actions) * SafeConfig.noise_scale
            q1_rand, q2_rand = self.q_net(states, rand_acts)
            conservative_loss = (q1_rand + q2_rand).mean() - (current_q1 + current_q2).mean()
            # Total loss
            loss = bellman_loss + self.conservative_weight * conservative_loss
            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 2.0)
            self.optimizer.step()
            # Update the target networks
            self.q_net._update_target(SafeConfig.tau)
            # Update sampling priorities
            td_errors = (current_q1 - y).detach().cpu().numpy()
            self.dataset.update_priorities(batch['indices'].numpy(), td_errors)
            total_loss += loss.item()
        # Decay the conservative weight
        self.conservative_weight = max(
            self.conservative_weight * SafeConfig.conservative_decay,
            SafeConfig.min_conservative)
        # Learning-rate schedule
        self.scheduler.step()
        return total_loss / len(self.dataloader)

    def get_action(self, state):
        self.q_net.eval()
        state_norm = (state - self.dataset.obs_mean) / self.dataset.obs_std
        state_tensor = torch.FloatTensor(state_norm).unsqueeze(0).to(self.device)
        # Candidate action generation: mix dataset (imitation) actions with random noise
        num_imitation = int(SafeConfig.candidate_samples * SafeConfig.imitation_ratio)
        imitation_idx = np.random.choice(len(self.dataset), num_imitation)
        imitation_acts = self.dataset.acts[imitation_idx]
        noise_acts = np.random.randn(SafeConfig.candidate_samples - num_imitation, self.action_dim)
        candidates = np.concatenate([imitation_acts, noise_acts])
        candidates = (candidates * self.dataset.act_std) + self.dataset.act_mean
        # Select the best action by Q-value (normalize in numpy before moving to the device)
        with torch.no_grad():
            state_batch = state_tensor.repeat(SafeConfig.candidate_samples, 1)
            candidate_norm = (candidates - self.dataset.act_mean) / self.dataset.act_std
            candidate_tensor = torch.FloatTensor(candidate_norm).to(self.device)
            q_values, _ = self.q_net(state_batch, candidate_tensor)
            best_idx = torch.argmax(q_values)
        return candidates[best_idx.cpu().item()]
# 5. Training entry point
if __name__ == "__main__":trainer = SafeTrainer("D4RL/door/human-v2")print(f"初始化维度检查: state={trainer.state_dim}, action={trainer.action_dim}")try:for epoch in range(SafeConfig.total_epochs):loss = trainer.train_epoch(epoch)if (epoch + 1) % 20 == 0:print(f"Epoch {epoch+1:04d} | Loss: {loss:.2f} | "f"Conserv: {trainer.conservative_weight:.2f} | "f"LR: {trainer.scheduler.get_last_lr()[0]:.1e}")except KeyboardInterrupt:print("\n训练中断,保存检查点...")torch.save(trainer.q_net.state_dict(), "interrupted.pth")print("训练完成...")
IV. Key Code Walkthrough
- Dataset loading
  - minari.load_dataset is used to load the offline dataset
  - The dataset contains states, actions, rewards, termination flags, and related metadata
- Conservative regularization
  - The regularization term is computed from randomly sampled actions
  - The hyperparameter $\alpha$ controls the degree of conservatism
- Policy extraction
  - A Q-value-based heuristic policy is used
  - Multiple candidate actions are sampled to improve stability (a compact sketch follows this list)
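As a compact illustration of the candidate-sampling idea (stripped of the normalization and imitation-action mixing that get_action performs above, and again assuming a q_net that returns a single Q-value per pair):

import torch

def extract_action(q_net, state, action_dim, num_candidates=400):
    # Score a batch of random candidate actions with the Q-network and return the best one
    candidates = torch.randn(num_candidates, action_dim)
    state_batch = state.unsqueeze(0).repeat(num_candidates, 1)
    with torch.no_grad():
        q_values = q_net(state_batch, candidates).squeeze(-1)
    return candidates[torch.argmax(q_values)]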
V. Training Results
Running the code produces output like the following:
Dimension check at init: state=39, action=28
Epoch 0020 | Loss: -46.52 | Conserv: 2.26 | LR: 2.7e-05
Epoch 0040 | Loss: -73.80 | Conserv: 2.05 | LR: 2.0e-05
Epoch 0060 | Loss: -73.50 | Conserv: 1.85 | LR: 1.1e-05
Epoch 0080 | Loss: -64.76 | Conserv: 1.67 | LR: 3.8e-06
Epoch 0100 | Loss: -54.37 | Conserv: 1.51 | LR: 3.0e-05
Epoch 0120 | Loss: -59.95 | Conserv: 1.37 | LR: 2.7e-05
Epoch 0140 | Loss: -60.11 | Conserv: 1.24 | LR: 2.0e-05
Epoch 0160 | Loss: -54.49 | Conserv: 1.12 | LR: 1.1e-05
Epoch 0180 | Loss: -46.11 | Conserv: 1.01 | LR: 3.8e-06
Epoch 0200 | Loss: -37.10 | Conserv: 0.92 | LR: 3.0e-05
Epoch 0220 | Loss: -37.56 | Conserv: 0.83 | LR: 2.7e-05
Epoch 0240 | Loss: -36.40 | Conserv: 0.75 | LR: 2.0e-05
Epoch 0260 | Loss: -31.79 | Conserv: 0.68 | LR: 1.1e-05
Epoch 0280 | Loss: -24.44 | Conserv: 0.61 | LR: 3.8e-06
Epoch 0300 | Loss: -17.06 | Conserv: 0.56 | LR: 3.0e-05
Epoch 0320 | Loss: -17.40 | Conserv: 0.50 | LR: 2.7e-05
Epoch 0340 | Loss: -16.91 | Conserv: 0.45 | LR: 2.0e-05
Epoch 0360 | Loss: -12.76 | Conserv: 0.41 | LR: 1.1e-05
Epoch 0380 | Loss: -7.27 | Conserv: 0.37 | LR: 3.8e-06
Epoch 0400 | Loss: -0.27 | Conserv: 0.34 | LR: 3.0e-05
Epoch 0420 | Loss: -1.47 | Conserv: 0.30 | LR: 2.7e-05
Epoch 0440 | Loss: -2.50 | Conserv: 0.30 | LR: 2.0e-05
Epoch 0460 | Loss: -2.87 | Conserv: 0.30 | LR: 1.1e-05
Epoch 0480 | Loss: -2.64 | Conserv: 0.30 | LR: 3.8e-06
Epoch 0500 | Loss: -2.30 | Conserv: 0.30 | LR: 3.0e-05
Training finished...
VI. Summary and Extensions
This article implemented the core logic of the CQL algorithm on top of Minari datasets and demonstrated the value of offline reinforcement learning for safety-critical scenarios. Readers can try the following extensions:
- Add a policy network to build an Actor-Critic architecture (a minimal actor sketch follows this list)
- Test navigation capability on maze-style datasets such as antmaze
- Implement more precise OOD (out-of-distribution) action detection
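For the first extension, a minimal actor network could look like the sketch below; the architecture, tanh squashing, and the loss shown in the comment are illustrative assumptions, not part of the code above:

import torch

class Actor(torch.nn.Module):
    # Minimal deterministic policy head for an Actor-Critic extension of the CQL code above
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim), torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, action_dim), torch.nn.Tanh(),  # actions squashed to [-1, 1]
        )

    def forward(self, state):
        return self.net(state)

# The actor would then be trained to maximize the first Q-head, e.g.:
# actor_loss = -q_net(states, actor(states))[0].mean()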
In the next article, we will explore Model-Based Reinforcement Learning (Model-Based RL) and implement the PETS algorithm!
Notes:
- Install the minari library first: pip install "minari[all]"
- Available datasets can be checked via minari.list_datasets()
- Adjusting the alpha parameter balances conservatism against exploration
I hope this article helps you understand the core paradigm of offline reinforcement learning! Feel free to share your hands-on experience in the comments.