Skip to content

Commit b0980e2

Browse files
committed
Merge commit from research-vgsm branch.
2 parents 3e7d57d + 9bf4a0c commit b0980e2

File tree

15 files changed

+456
-95
lines changed

15 files changed

+456
-95
lines changed

examples/mujoco/a2c_torch.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict
22
from argparse import ArgumentParser
3+
import numpy as np
34

45
from modular_baselines.algorithms.a2c import A2C, A2CArgs
56
from modular_baselines.algorithms.a2c.torch_agent import TorchA2CAgent
@@ -9,8 +10,8 @@
910
from torch_setup import MujocoTorchConfig, setup, parallel_run, add_arguments
1011

1112

12-
def a2c_setup(env_name: str, config: Dict[str, Any], seed: int):
13-
return setup(A2C, TorchA2CAgent, SeparateFeatureNetwork, env_name, config, seed)
13+
def a2c_setup(env_name: str, config: MujocoTorchConfig, experiment_name: str, device: str):
    """Build the A2C algorithm for one environment/configuration pair.

    Thin wrapper around ``setup`` that fixes the algorithm (``A2C``), the
    agent (``TorchA2CAgent``) and the network (``SeparateFeatureNetwork``).

    Args:
        env_name (str): Name of the gym environment
        config (MujocoTorchConfig): Experiment configuration forwarded to ``setup``
        experiment_name (str): Prefix of the experiment name
        device (str): Torch device forwarded to ``setup``

    Returns:
        The object returned by ``setup``
    """
    return setup(A2C, TorchA2CAgent, SeparateFeatureNetwork,
                 experiment_name, env_name, config, device=device)
1415

1516

1617
a2c_mujoco_config = MujocoTorchConfig(
@@ -24,15 +25,22 @@ def a2c_setup(env_name: str, config: Dict[str, Any], seed: int):
2425
max_grad_norm=1.0,
2526
normalize_advantage=True,
2627
),
28+
name="default",
2729
n_envs=16,
2830
total_timesteps=5_000_000,
2931
log_interval=256,
30-
device="cpu",
32+
use_vec_normalization=True,
33+
record_video=False,
34+
seed=np.random.randint(2**10, 2**30),
3135
)
3236

3337
if __name__ == "__main__":
    # Collect the shared experiment options from the command line and hand
    # the A2C setup function plus its configuration to the parallel runner.
    arg_parser = ArgumentParser("A2C Mujoco")
    add_arguments(arg_parser)
    options = arg_parser.parse_args()

    parallel_run(
        a2c_setup,
        a2c_mujoco_config,
        n_procs=options.n_procs,
        env_names=options.env_names,
        experiment_name=options.experiment_name,
        cuda_devices=options.cuda_devices,
    )

examples/mujoco/lstm_ppo_torch.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict
22
from argparse import ArgumentParser
3+
import multiprocessing as mp
34

45
from modular_baselines.algorithms.ppo.ppo import LstmPPO, LstmPPOArgs
56
from modular_baselines.algorithms.ppo.torch_lstm_agent import TorchLstmPPOAgent
@@ -9,11 +10,11 @@
910
from torch_setup import MujocoTorchConfig, setup, parallel_run, add_arguments
1011

1112

12-
def ppo_setup(env_name: str, config: Dict[str, Any], seed: int):
13-
return setup(LstmPPO, TorchLstmPPOAgent, LSTMSeparateNetwork, env_name, config, seed)
13+
def lstm_ppo_setup(experiment_name: str, env_name: str, config: MujocoTorchConfig,
                   seed: Any = None, device: str = "cpu"):
    """Build the LSTM-PPO algorithm for one environment/configuration pair.

    Thin wrapper around ``setup`` that fixes the algorithm (``LstmPPO``), the
    agent (``TorchLstmPPOAgent``) and the network (``LSTMSeparateNetwork``).

    Args:
        experiment_name (str): Prefix of the experiment name
        env_name (str): Name of the gym environment
        config (MujocoTorchConfig): Experiment configuration forwarded to ``setup``
        seed (Any, optional): Ignored. Kept only so existing callers that still
            pass it do not break; ``setup`` has no seed parameter and the
            parallel runner never supplies one. Defaults to None.
        device (str, optional): Torch device forwarded to ``setup``. The
            default exists only because ``seed`` now has one. Defaults to "cpu".

    Returns:
        The object returned by ``setup``
    """
    # ``setup`` takes exactly (algorithm, agent, network, experiment_name,
    # env_name, config, device); forwarding ``seed`` positionally would raise
    # a TypeError, so it is intentionally dropped here.
    return setup(LstmPPO, TorchLstmPPOAgent, LSTMSeparateNetwork,
                 experiment_name, env_name, config, device=device)
1415

1516

16-
lstm_ppo_mujoco_config = MujocoTorchConfig(
17+
lstm_ppo_mujoco_config = [MujocoTorchConfig(
1718
args=LstmPPOArgs(
1819
rollout_len=2048,
1920
ent_coef=1e-4,
@@ -23,21 +24,23 @@ def ppo_setup(env_name: str, config: Dict[str, Any], seed: int):
2324
epochs=10,
2425
lr=LinearAnnealing(3e-4, 0.0, 5_000_000 // (2048 * 16)),
2526
clip_value=LinearAnnealing(0.2, 0.2, 5_000_000 // (2048 * 16)),
26-
batch_size=8,
27+
batch_size=64 // n_step,
2728
max_grad_norm=1.0,
2829
normalize_advantage=True,
29-
mini_rollout_size=8,
30+
mini_rollout_size=n_step,
3031
use_sampled_hidden=False,
3132
),
33+
name=f"{n_step}_step",
3234
n_envs=16,
3335
total_timesteps=5_000_000,
3436
log_interval=1,
35-
device="cpu",
36-
)
37+
) for n_step in (1, 2, 4, 8, 16, 32, 64)]
3738

3839
if __name__ == "__main__":
    # Select the "spawn" start method before any worker process is created.
    mp.set_start_method("spawn")
    parser = ArgumentParser("PPO Mujoco")
    add_arguments(parser)
    cli_args = parser.parse_args()
    # ``add_arguments`` no longer registers ``--n-seeds`` and ``parallel_run``
    # has no ``n_seeds`` parameter, so the option is not read or forwarded
    # here; doing so would fail before any experiment starts.
    parallel_run(lstm_ppo_setup,
                 lstm_ppo_mujoco_config,
                 n_procs=cli_args.n_procs,
                 env_names=cli_args.env_names,
                 experiment_name=cli_args.experiment_name,
                 cuda_devices=cli_args.cuda_devices)

examples/mujoco/torch_setup.py

Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,112 @@
1-
from typing import List, Any, Dict, Union, Optional, Tuple, Callable, Type
1+
from typing import List, Any, Dict, Union, Optional, Tuple, Callable, Type, Iterable
22
import torch
33
import os
44
import numpy as np
55
import sys
66
from multiprocessing import Process, Queue
77
from dataclasses import dataclass
88
import argparse
9+
import time
10+
import gym
11+
from datetime import datetime
912

1013
from stable_baselines3.common.env_util import make_vec_env
14+
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
1115
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
1216
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
1317
from stable_baselines3.common.logger import HumanOutputFormat, CSVOutputFormat, JSONOutputFormat
18+
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
19+
from stable_baselines3.common.running_mean_std import RunningMeanStd
1420

1521
from modular_baselines.algorithms.algorithm import BaseAlgorithm
1622
from modular_baselines.algorithms.agent import BaseAgent
17-
from modular_baselines.loggers.writers import ScalarWriter, DictWriter
23+
from modular_baselines.loggers.writers import ScalarWriter, DictWriter, BaseWriter, SaveModelParametersWriter, LogConfigs
1824
from modular_baselines.loggers.data_logger import DataLogger
1925

2026

2127
@dataclass(frozen=True)
class MujocoTorchConfig:
    """Immutable description of a single Mujoco experiment run.

    All fields are required and instances cannot be modified after creation.
    """
    # Algorithm-specific arguments object (``pre_setup`` reads ``gamma``,
    # ``use_vec_normalization`` and ``vec_norm_info`` from it).
    args: Any
    # Configuration label; becomes one component of the log directory path.
    name: str
    # Number of environments passed to ``make_vec_env``.
    n_envs: int
    total_timesteps: int
    # Interval handed to the logger writers.
    log_interval: int
    # When true, the vectorized environment is wrapped in a video recorder.
    record_video: bool
    # Seed used for numpy, torch and the vectorized environment.
    seed: int
36+
37+
38+
def pre_setup(experiment_name: str,
              env: Union[gym.Env, str],
              config: MujocoTorchConfig,
              ) -> Tuple[DataLogger, List[BaseWriter], VecEnv]:
    """ Prepare loggers and vectorized environment

    Besides building the returned objects, this function seeds the global
    numpy and torch random generators with ``config.seed`` and creates the
    log directory ``logs/<experiment_name>-<env name>/<config.name>/<timestamp>``.

    Args:
        experiment_name (str): Name of the experiment
        env (Union[gym.Env, str]): Name of the environment or the environment
        config (MujocoTorchConfig): Torch Mujoco configuration

    Returns:
        Tuple[DataLogger, List[BaseWriter], VecEnv]: Data logger, writers list and vectorized
            environment
    """
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    # A string is used as-is; for an environment object its class name is used.
    env_name = env if isinstance(env, str) else env.__class__.__name__
    # Second-resolution timestamp: two runs with the same experiment,
    # environment and config name that start within the same second resolve
    # to the same directory (``makedirs`` below tolerates an existing one).
    date_time = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    log_dir = f"logs/{experiment_name}-{env_name.lower()}/{config.name}/{date_time}"
    data_logger = DataLogger()
    os.makedirs(log_dir, exist_ok=True)
    # Scalar values go to stdout and to CSV/JSON progress files in the log directory.
    sb3_writers = [HumanOutputFormat(sys.stdout),
                   CSVOutputFormat(os.path.join(log_dir, "progress.csv")),
                   JSONOutputFormat(os.path.join(log_dir, "progress.json"))]
    logger_callbacks = [
        ScalarWriter(interval=config.log_interval, dir_path=log_dir, writers=sb3_writers),
        DictWriter(interval=config.log_interval, dir_path=log_dir),
        # The ``* 1`` multiplier keeps this interval equal to the logging interval.
        SaveModelParametersWriter(interval=config.log_interval * 1, dir_path=log_dir)
    ]

    vecenv = make_vec_env(
        env,
        n_envs=config.n_envs,
        seed=config.seed,
        wrapper_class=None,
        vec_env_cls=SubprocVecEnv)
    # The normalization switch and its options are read from ``config.args``,
    # not from ``config`` itself.
    if config.args.use_vec_normalization:
        vecenv = VecNormalize(
            vecenv,
            training=True,
            gamma=config.args.gamma,
            **config.args.vec_norm_info)
        # With observation normalization explicitly disabled, ``obs_rms`` is
        # replaced by a fresh running statistic shaped like the observation
        # space. NOTE(review): presumably so code reading ``obs_rms`` still
        # finds a valid object -- confirm.
        if config.args.vec_norm_info["norm_obs"] is False:
            vecenv.obs_rms = RunningMeanStd(shape=vecenv.observation_space.shape)
    if config.record_video:
        # The trigger fires whenever its argument is a multiple of 25000
        # (presumably the step counter -- confirm); each video is 1000 long.
        vecenv = VecVideoRecorder(
            vecenv,
            f"{log_dir}/videos",
            record_video_trigger=lambda x: x % 25000 == 0, video_length=1000
        )
    # Instantiated only for its effect; the result is discarded.
    # NOTE(review): presumably stores the configuration under ``log_dir`` -- confirm.
    LogConfigs(config=config, dir_path=log_dir)

    return data_logger, logger_callbacks, vecenv
93+
94+
95+
def setup(algorithm_cls: Type[BaseAlgorithm],
96+
agent_cls: Type[BaseAgent],
97+
network: Type[torch.nn.Module],
98+
experiment_name: str,
99+
env_name: str,
100+
config: MujocoTorchConfig,
101+
device: str
102+
) -> BaseAlgorithm:
103+
104+
experiment_name = "-".join([experiment_name, algorithm_cls.__name__])
105+
data_logger, logger_callbacks, vecenv = pre_setup(experiment_name, env_name, config)
58106

59107
policy = network(observation_space=vecenv.observation_space,
60108
action_space=vecenv.action_space)
61-
policy.to(config.device)
109+
policy.to(device)
62110
optimizer = torch.optim.Adam(policy.parameters(), eps=1e-5)
63111
agent = agent_cls(policy,
64112
optimizer,
@@ -80,39 +128,48 @@ def setup(algorithm_cls: Type[BaseAlgorithm],
80128

81129

82130
def add_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the command line options shared by the example scripts.

    Args:
        parser (argparse.ArgumentParser): Parser that receives the options
            ``--experiment-name``, ``--n-procs``, ``--env-names`` (required)
            and ``--cuda-devices``
    """
    # Declared as a table and registered in this order.
    option_table = (
        ("--experiment-name", dict(type=str, default="",
                                   help="Prefix of the experiment name")),
        ("--n-procs", dict(type=int, default=1,
                           help="Number of parallelized processes for experiments")),
        ("--env-names", dict(nargs='+', type=str, required=True,
                             help="Gym environment names")),
        ("--cuda-devices", dict(nargs='+', type=int, required=False,
                                help="Available cuda devices")),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
89139

90140

91-
def worker(setup_fn, argument_queue: Queue, rank: int) -> None:
141+
def worker(setup_fn, argument_queue: Queue, rank: int, cuda_devices,
           timeout: float = 1.0) -> None:
    """Run queued experiments one after another on this worker's device.

    Args:
        setup_fn: Callable invoked as ``setup_fn(device=device, **kwargs)`` for
            every keyword dictionary taken from the queue
        argument_queue (Queue): Queue of keyword dictionaries shared by all workers
        rank (int): Index of this worker; selects the cuda device round-robin
        cuda_devices: Sequence of cuda device indices, or None/empty to run on cpu
        timeout (float, optional): Seconds to wait for a queue item before
            concluding that no work is left. Defaults to 1.0.
    """
    from queue import Empty  # local import: only needed by this function

    # ``None`` and an empty sequence both mean "no cuda device"; testing
    # truthiness avoids a modulo-by-zero on an empty device list.
    if cuda_devices:
        device = f"cuda:{cuda_devices[rank % len(cuda_devices)]}"
    else:
        device = "cpu"
    print(f"Worker-{rank} use device: {device}")

    while True:
        try:
            # A single bounded ``get`` replaces the ``empty()``/``get()`` pair:
            # with several workers another one can take the last item between
            # the two calls, leaving this worker blocked forever in ``get()``.
            kwargs = argument_queue.get(timeout=timeout)
        except Empty:
            return
        setup_fn(device=device, **kwargs)
95147

96148

97149
def parallel_run(setup_fn: Callable[[str, MujocoTorchConfig, int], BaseAlgorithm],
98-
config: MujocoTorchConfig,
150+
configs: Union[MujocoTorchConfig, Iterable[MujocoTorchConfig]],
151+
experiment_name: str,
99152
n_procs: int,
100153
env_names: Tuple[str],
101-
n_seeds: int
154+
cuda_devices: Tuple[int],
102155
) -> None:
103156

104-
arguments = [dict(env_name=env_name, seed=seed, config=config)
157+
if not isinstance(configs, Iterable):
158+
configs = [configs]
159+
160+
arguments = [dict(env_name=env_name, config=config, experiment_name=experiment_name)
105161
for env_name in env_names
106-
for seed in np.random.randint(2 ** 10, 2 ** 30, size=n_seeds).tolist()]
162+
for config in configs]
107163

108164
argument_queue = Queue()
109165
for arg in arguments:
110166
argument_queue.put(arg)
111167

112-
processes = [Process(target=worker, args=(setup_fn, argument_queue, rank))
168+
processes = [Process(target=worker, args=(setup_fn, argument_queue, rank, cuda_devices))
113169
for rank in range(n_procs)]
114170

115171
for proc in processes:
172+
time.sleep(1.5) # To avoid having the same log name
116173
proc.start()
117174

118175
for proc in processes:

0 commit comments

Comments
 (0)