Note: the inner (intra-option) policy here is trained with plain vanilla policy gradient (PG), so training is very unstable. If you need better results, try tuning the hyperparameters yourself, or replace the inner-policy update with a more stable algorithm such as PPO; a minimal sketch of that swap follows the full listing below.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
from torch.distributions.categorical import Categorical


class NN(nn.Module):
    """Option-Critic network: one intra-option policy, one termination head and
    one critic Q_omega(s, a) per option."""

    def __init__(self, state_size, action_size, hidden_size, num_options):
        super().__init__()
        # Intra-option policies pi_omega(a | s)
        self.actors = nn.ModuleList([nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
            nn.Softmax(dim=-1)) for _ in range(num_options)])
        # Termination functions beta_omega(s) in (0, 1)
        self.terminations = nn.ModuleList([nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()) for _ in range(num_options)])
        # Option-value critics Q_omega(s, a)
        self.critics = nn.ModuleList([nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size)) for _ in range(num_options)])

    def select_option(self, state, epsilon):
        # Epsilon-greedy over options, scoring each option by
        # V_omega(s) = sum_a pi_omega(a|s) * Q_omega(s, a).
        if np.random.rand() >= epsilon:
            max_value = -np.inf
            option_id = -1
            for i, (a, c) in enumerate(zip(self.actors, self.critics)):
                q = c(state)
                p = a(state)
                v = (q * p).sum(-1).item()
                if v >= max_value:
                    option_id = i
                    max_value = v
        else:
            option_id = np.random.randint(0, len(self.actors))
        return self.actors[option_id], self.terminations[option_id], option_id


if __name__ == '__main__':
    np.random.seed(0)
    episodes = 5000
    epsilon = 1.0
    discount = 0.9
    epsilon_decay = 0.995
    epsilon_min = 0.05
    training_epochs = 1          # update (and reset the buffer) every `training_epochs` episodes
    env = gym.make('CartPole-v1')
    model = NN(4, 2, 128, 6)     # 4-dim state, 2 actions, 6 options
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    max_score = 0.0
    trajectory = []

    for e in range(1, episodes + 1):
        if e % training_epochs == 0:
            trajectory = []
        score = 0.0
        state, _ = env.reset()
        option = model.select_option(torch.tensor(state), epsilon)  # (actor, termination, option_id)
        while True:
            # Act with the current option's intra-option policy.
            policy = option[0](torch.tensor(state))
            action = Categorical(policy).sample()
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            score += reward
            # Sample option termination from beta_omega(s').
            beta = option[1](torch.tensor(next_state)).item()
            if np.random.rand() > beta:
                # Option continues.
                trajectory.append((state, action, reward, next_state, done, option[2], beta, False))
            else:
                # Option terminates; pick a new option epsilon-greedily.
                trajectory.append((state, action, reward, next_state, done, option[2], beta, True))
                option = model.select_option(torch.tensor(next_state), epsilon)
            state = next_state
            if done:
                break

        if e % training_epochs == 0:
            optimizer.zero_grad()
            q_targets = []
            option_states = []
            option_advs = []
            option_next_states = []
            for state, action, reward, next_state, done, option_id, beta, option_terminal in trajectory:
                # One-step target: with probability (1 - beta) we stay in the same option,
                # with probability beta we switch to the best option (option value V = sum_a pi * Q).
                q = reward + (1 - done) * discount * (
                    (1 - beta) * (model.critics[option_id](torch.tensor(next_state)) *
                                  model.actors[option_id](torch.tensor(next_state))).sum(-1).item() +
                    beta * max([(model.critics[i](torch.tensor(next_state)) *
                                 model.actors[i](torch.tensor(next_state))).sum(-1).item()
                                for i in range(len(model.critics))]))
                # Only the taken action's entry is replaced by the target.
                q_target = model.critics[option_id](torch.tensor(state)).detach().numpy()
                q_target[action] = q
                q_targets.append(q_target)
                option_states.append(state)
                # Termination advantage: value of continuing this option minus
                # the value of the best option at s'.
                inner_next_value = (model.critics[option_id](torch.tensor(next_state)).detach().numpy() *
                                    model.actors[option_id](torch.tensor(next_state)).detach().numpy()).sum(-1).item()
                next_value = max([(model.critics[i](torch.tensor(next_state)).detach().numpy() *
                                   model.actors[i](torch.tensor(next_state)).detach().numpy()).sum(-1).item()
                                  for i in range(len(model.critics))])
                option_adv = inner_next_value - next_value
                option_advs.append(option_adv)
                option_next_states.append(next_state)
                if option_terminal:
                    # The option just ended: accumulate gradients for this segment.
                    option_states = torch.tensor(np.array(option_states))
                    q_targets = torch.tensor(np.array(q_targets))
                    option_advs = torch.tensor(np.array(option_advs)).view(-1, 1)
                    option_next_states = torch.tensor(np.array(option_next_states))
                    # Critic: regress Q_omega(s, .) towards the one-step targets.
                    option_critic_loss = F.mse_loss(model.critics[option_id](option_states), q_targets)
                    # Actor (plain PG): the advantage is non-zero only for the taken action.
                    actor_advs = q_targets - model.critics[option_id](option_states).detach()
                    option_actor_loss = -(torch.log(model.actors[option_id](option_states)) * actor_advs).mean()
                    # Termination: raise beta where this option is worse than the best alternative.
                    option_terminal_loss = (model.terminations[option_id](option_next_states) * option_advs).mean()
                    option_critic_loss.backward()
                    option_actor_loss.backward()
                    option_terminal_loss.backward()
                    q_targets = []
                    option_states = []
                    option_advs = []
                    option_next_states = []
            optimizer.step()

        if epsilon > epsilon_min:
            epsilon *= epsilon_decay
        if score > max_score:
            max_score = score
            torch.save(model, 'NN.pt')
        print("Episode: {}/{}, Epsilon: {}, Score: {}, Max score: {}".format(
            e, episodes, epsilon, score, max_score))
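As suggested in the note at the top, the vanilla PG actor update (the option_actor_loss line) could be swapped for a PPO-style clipped surrogate. The sketch below is only an illustration, not part of the original script: it assumes you additionally store the behaviour policy's log-probabilities of the taken actions while collecting the trajectory, and both the helper name ppo_actor_loss and the clip_eps hyperparameter are made up here for the example.

import torch

def ppo_actor_loss(actor, states, actions, advantages, old_log_probs, clip_eps=0.2):
    """Clipped surrogate loss for one option's intra-option policy.

    states:        (N, state_size) float tensor
    actions:       (N,) long tensor of taken actions
    advantages:    (N,) float tensor, e.g. q_target[a] - V(s), detached
    old_log_probs: (N,) float tensor of log pi_old(a|s), detached
    """
    # Log-probabilities of the taken actions under the current policy.
    new_log_probs = torch.log(actor(states).gather(1, actions.view(-1, 1)).squeeze(1) + 1e-8)
    ratio = torch.exp(new_log_probs - old_log_probs)          # pi_new / pi_old
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Pessimistic (clipped) surrogate, negated because we minimise.
    return -torch.min(unclipped, clipped).mean()

In a full PPO setup you would also reuse each batch for several epochs of minibatch updates; that reuse, together with the clipping, is where most of the stability gain comes from.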
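Finally, a minimal evaluation sketch, assuming training has already saved NN.pt and the NN class definition is in scope. Options are chosen greedily (epsilon = 0) and, unlike during training, termination is thresholded at 0.5 instead of sampled; the weights_only argument to torch.load is only needed (and only accepted) on newer PyTorch versions.

import gym
import torch
from torch.distributions.categorical import Categorical

env = gym.make('CartPole-v1')
# Full-module checkpoint: recent PyTorch needs weights_only=False to unpickle it;
# drop the argument on versions that do not accept the keyword.
model = torch.load('NN.pt', weights_only=False)
model.eval()

state, _ = env.reset()
option = model.select_option(torch.tensor(state), 0.0)    # greedy option choice
score = 0.0
with torch.no_grad():
    while True:
        action = Categorical(option[0](torch.tensor(state))).sample()
        state, reward, terminated, truncated, _ = env.step(action.item())
        score += reward
        if option[1](torch.tensor(state)).item() > 0.5:    # thresholded termination
            option = model.select_option(torch.tensor(state), 0.0)
        if terminated or truncated:
            break
print("Evaluation score:", score)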