Add Parallel Q-Networks algorithm (PQN) (vwxyzjn#472)
1 parent 38c313f · commit e648ee2
Showing 10 changed files with 1,255 additions and 0 deletions.
@@ -0,0 +1,32 @@
# pqn.py (classic control)
poetry install
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
    --command "poetry run python cleanrl/pqn.py --no_cuda --track" \
    --num-seeds 3 \
    --workers 9 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 10 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

# pqn_atari_envpool.py (Atari via Envpool)
poetry install -E envpool
poetry run python -m cleanrl_utils.benchmark \
    --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
    --command "poetry run python cleanrl/pqn_atari_envpool.py --track" \
    --num-seeds 3 \
    --workers 9 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 10 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

# pqn_atari_envpool_lstm.py (Atari via Envpool, LSTM variant)
poetry install -E envpool
poetry run python -m cleanrl_utils.benchmark \
    --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
    --command "poetry run python cleanrl/pqn_atari_envpool_lstm.py --track" \
    --num-seeds 3 \
    --workers 9 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 10 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
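
For a quick single-seed run outside the Slurm harness, something like the following should work (a minimal sketch; the flag spellings assume tyro exposes the Args fields of cleanrl/pqn.py in the same underscore style as the --no_cuda flag above):

# hypothetical single-run invocation, derived from the Args dataclass in cleanrl/pqn.py
poetry install
poetry run python cleanrl/pqn.py --env_id CartPole-v1 --total_timesteps 500000 --seed 1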
@@ -0,0 +1,50 @@
# pqn.py (classic control)
python -m openrlbenchmark.rlops \
    --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'pqn?tag=pr-472&cl=CleanRL PQN' \
    --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 2 \
    --output-filename benchmark/cleanrl/pqn \
    --scan-history

# pqn_atari_envpool.py (Atari via Envpool)
python -m openrlbenchmark.rlops \
    --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'pqn_atari_envpool?tag=pr-472&cl=CleanRL PQN' \
    --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 3 \
    --rliable \
    --rc.score_normalization_method maxmin \
    --rc.normalized_score_threshold 1.0 \
    --rc.sample_efficiency_plots \
    --rc.sample_efficiency_and_walltime_efficiency_method Median \
    --rc.performance_profile_plots \
    --rc.aggregate_metrics_plots \
    --rc.sample_efficiency_num_bootstrap_reps 10 \
    --rc.performance_profile_num_bootstrap_reps 10 \
    --rc.interval_estimates_num_bootstrap_reps 10 \
    --output-filename static/0compare \
    --scan-history

# pqn_atari_envpool_lstm.py (Atari via Envpool, LSTM variant)
python -m openrlbenchmark.rlops \
    --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'pqn_atari_envpool_lstm?tag=pr-472&cl=CleanRL PQN' \
    --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 MsPacman-v5 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 3 \
    --rliable \
    --rc.score_normalization_method maxmin \
    --rc.normalized_score_threshold 1.0 \
    --rc.sample_efficiency_plots \
    --rc.sample_efficiency_and_walltime_efficiency_method Median \
    --rc.performance_profile_plots \
    --rc.aggregate_metrics_plots \
    --rc.sample_efficiency_num_bootstrap_reps 10 \
    --rc.performance_profile_num_bootstrap_reps 10 \
    --rc.interval_estimates_num_bootstrap_reps 10 \
    --output-filename static/0compare \
    --scan-history
@@ -0,0 +1,247 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 500000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 4
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run for each environment per update"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    anneal_lr: bool = True
    """Toggle learning rate annealing"""
    gamma: float = 0.99
    """the discount factor gamma"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.05
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.5
    """the fraction of `total_timesteps` it takes from start_e to end_e"""
    max_grad_norm: float = 10.0
    """the maximum norm for the gradient clipping"""
    q_lambda: float = 0.65
    """the lambda for Q(lambda)"""


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)

        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()

        self.network = nn.Sequential(
            layer_init(nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)),
            nn.LayerNorm(120),
            nn.ReLU(),
            layer_init(nn.Linear(120, 84)),
            nn.LayerNorm(84),
            nn.ReLU(),
            layer_init(nn.Linear(84, env.single_action_space.n)),
        )

    def forward(self, x):
        return self.network(x)
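
# NOTE: the LayerNorm after each hidden layer above is the main PQN-specific ingredient: per the PQN
# paper, normalized TD learning on fresh rollouts from the vectorized envs is what lets this script
# drop DQN's target network and replay buffer entirely (neither appears anywhere below).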


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)
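# e.g., with the defaults (start_e=1, end_e=0.05, exploration_fraction=0.5, total_timesteps=500000),
# duration = 0.5 * 500000 = 250000, so epsilon = max(1 - 0.95 * global_step / 250000, 0.05):
# it decays linearly from 1.0 to 0.05 over the first 250k global steps and then stays at 0.05.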


if __name__ == "__main__":
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
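    # e.g., with the defaults: batch_size = 4 * 128 = 512, minibatch_size = 512 // 4 = 128,
    # and num_iterations = 500000 // 512 = 976 rollout/update iterations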
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # agent setup
    q_network = QNetwork(envs).to(device)
    optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

    # storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
            random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
            with torch.no_grad():
                q_values = q_network(next_obs)
                max_actions = torch.argmax(q_values, dim=1)
                values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

            explore = torch.rand((args.num_envs,)).to(device) < epsilon
            action = torch.where(explore, random_actions, max_actions)
            actions[step] = action

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # Compute Q(lambda) targets
        with torch.no_grad():
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    next_value, _ = torch.max(q_network(next_obs), dim=-1)
                    nextnonterminal = 1.0 - next_done
                    returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    next_value = values[t + 1]
                    returns[t] = rewards[t] + args.gamma * (
                        args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
                    )
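        # The backward recursion above computes the Q(lambda) returns:
        #   R_t = r_t + gamma * [lambda * R_{t+1} + (1 - lambda) * (1 - done_{t+1}) * max_a Q(s_{t+1}, a)]
        # bootstrapping with a plain one-step target at the final step of the rollout; q_lambda (0.65 by
        # default) interpolates between one-step TD targets (lambda = 0) and Monte Carlo-like multi-step
        # returns (lambda = 1).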

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_returns = returns.reshape(-1)

        # Optimizing the Q-network
        b_inds = np.arange(args.batch_size)
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
                loss = F.mse_loss(b_returns[mb_inds], old_val)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
                optimizer.step()

        writer.add_scalar("losses/td_loss", loss, global_step)
        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
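
The script writes its metrics to TensorBoard event files under runs/<run_name> (and mirrors them to Weights and Biases when --track is set), so a quick local way to inspect a run is:

tensorboard --logdir runs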