
Commit df3b319

Add normalization to existing examples
1 parent b0980e2 commit df3b319

File tree

10 files changed: +144 −182 lines changed


examples/mujoco/a2c_torch.py

Lines changed: 7 additions & 1 deletion
@@ -24,12 +24,18 @@ def a2c_setup(env_name: str, config: Dict[str, Any], experiment_name, device: str):
         lr=LinearAnnealing(3e-4, 0.0, 5_000_000 // (8 * 16)),
         max_grad_norm=1.0,
         normalize_advantage=True,
+        use_vec_normalization=True,
+        vec_norm_info={
+            "norm_reward": True,
+            "norm_obs": True,
+            "clip_obs": 1e5,
+            "clip_reward": 1e5,
+        },
     ),
     name="default",
     n_envs=16,
     total_timesteps=5_000_000,
     log_interval=256,
-    use_vec_normalization=True,
     record_video=False,
     seed=np.random.randint(2**10, 2**30),
 )
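
Note: the four vec_norm_info keys mirror the constructor arguments of an SB3-style VecNormalize wrapper. How modular_baselines consumes this dict is not shown in this commit, so the following is a sketch under that assumption, using Stable Baselines3 only for illustration:

    # Assumed mapping: each vec_norm_info key corresponds to a
    # VecNormalize constructor argument of the same name.
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecNormalize

    venv = make_vec_env("HalfCheetah-v4", n_envs=16)
    venv = VecNormalize(
        venv,
        norm_obs=True,     # "norm_obs": standardize observations
        norm_reward=True,  # "norm_reward": scale rewards
        clip_obs=1e5,      # "clip_obs": bound so large it barely clips
        clip_reward=1e5,   # "clip_reward": likewise
    )

With clip bounds of 1e5, clipping is effectively disabled; the change is about standardizing observations and scaling rewards, not truncating them.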

examples/mujoco/lstm_a2c_torch.py

Lines changed: 15 additions & 3 deletions
@@ -23,16 +23,28 @@ def a2c_setup(env_name: str, config: Dict[str, Any], seed: int):
         lr=LinearAnnealing(3e-4, 0.0, 5_000_000 // (8 * 16)),
         max_grad_norm=1.0,
         normalize_advantage=True,
+        use_vec_normalization=True,
+        vec_norm_info={
+            "norm_reward": True,
+            "norm_obs": True,
+            "clip_obs": 1e5,
+            "clip_reward": 1e5,
+        },
     ),
     n_envs=16,
     total_timesteps=5_000_000,
     log_interval=256,
-    device="cpu",
+    record_video=False,
+    seed=np.random.randint(2**10, 2**30),
 )

 if __name__ == "__main__":
     parser = ArgumentParser("LSTM A2C Mujoco")
     add_arguments(parser)
     cli_args = parser.parse_args()
-    parallel_run(a2c_setup, lstm_a2c_mujoco_config, n_procs=cli_args.n_procs,
-                 env_names=cli_args.env_names, n_seeds=cli_args.n_seeds)
+    parallel_run(a2c_setup,
+                 lstm_a2c_mujoco_config,
+                 n_procs=cli_args.n_procs,
+                 env_names=cli_args.env_names,
+                 experiment_name=cli_args.experiment_name,
+                 cuda_devices=cli_args.cuda_devices)
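
Note: the *_rms_mean / *_rms_var statistics that these examples now record (see the modular_baselines/algorithms/a2c/a2c.py diff below) are typically maintained by a running mean/std tracker updated with the parallel-variance merge formula (Chan et al.). A self-contained sketch; the class and method names are illustrative, not modular_baselines API:

    import numpy as np

    class RunningMeanStd:
        # Illustrative tracker: maintains running mean/var across batches
        # via the parallel-variance merge used by common VecNormalize
        # implementations.
        def __init__(self, shape=()):
            self.mean = np.zeros(shape, np.float64)
            self.var = np.ones(shape, np.float64)
            self.count = 1e-4  # avoids division by zero on the first batch

        def update(self, batch):
            batch_mean, batch_var = batch.mean(axis=0), batch.var(axis=0)
            batch_count = batch.shape[0]
            delta = batch_mean - self.mean
            total = self.count + batch_count
            self.mean = self.mean + delta * batch_count / total
            m2 = (self.var * self.count + batch_var * batch_count
                  + delta ** 2 * self.count * batch_count / total)
            self.var = m2 / total
            self.count = total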

examples/mujoco/lstm_ppo_torch.py

Lines changed: 22 additions & 10 deletions
@@ -1,6 +1,7 @@
 from typing import Any, Dict
 from argparse import ArgumentParser
 import multiprocessing as mp
+import numpy as np

 from modular_baselines.algorithms.ppo.ppo import LstmPPO, LstmPPOArgs
 from modular_baselines.algorithms.ppo.torch_lstm_agent import TorchLstmPPOAgent
@@ -10,11 +11,11 @@
 from torch_setup import MujocoTorchConfig, setup, parallel_run, add_arguments


-def lstm_ppo_setup(experiment_name: str, env_name: str, config: MujocoTorchConfig, seed: int, device: str):
-    return setup(LstmPPO, TorchLstmPPOAgent, LSTMSeparateNetwork, experiment_name, env_name, config, seed, device)
+def lstm_ppo_setup(env_name: str, config: Dict[str, Any], experiment_name: str, device: str):
+    return setup(LstmPPO, TorchLstmPPOAgent, LSTMSeparateNetwork, experiment_name, env_name, config, device)


-lstm_ppo_mujoco_config = [MujocoTorchConfig(
+lstm_ppo_mujoco_config = MujocoTorchConfig(
     args=LstmPPOArgs(
         rollout_len=2048,
         ent_coef=1e-4,
@@ -24,23 +25,34 @@ def lstm_ppo_setup(experiment_name: str, env_name: str, config: MujocoTorchConfig, seed: int, device: str):
         epochs=10,
         lr=LinearAnnealing(3e-4, 0.0, 5_000_000 // (2048 * 16)),
         clip_value=LinearAnnealing(0.2, 0.2, 5_000_000 // (2048 * 16)),
-        batch_size=64 // n_step,
+        batch_size=64 // 16,
         max_grad_norm=1.0,
         normalize_advantage=True,
-        mini_rollout_size=n_step,
+        mini_rollout_size=16,
         use_sampled_hidden=False,
+        use_vec_normalization=True,
+        vec_norm_info={
+            "norm_reward": True,
+            "norm_obs": True,
+            "clip_obs": 1e5,
+            "clip_reward": 1e5,
+        },
     ),
-    name=f"{n_step}_step",
+    name=f"default_{16}_step",
     n_envs=16,
     total_timesteps=5_000_000,
     log_interval=1,
-) for n_step in (1, 2, 4, 8, 16, 32, 64)]
+    record_video=False,
+    seed=np.random.randint(2**10, 2**30))

 if __name__ == "__main__":
     mp.set_start_method("spawn")
     parser = ArgumentParser("PPO Mujoco")
     add_arguments(parser)
     cli_args = parser.parse_args()
-    parallel_run(lstm_ppo_setup, lstm_ppo_mujoco_config, n_procs=cli_args.n_procs,
-                 env_names=cli_args.env_names, experiment_name=cli_args.experiment_name,
-                 n_seeds=cli_args.n_seeds, cuda_devices=cli_args.cuda_devices)
+    parallel_run(lstm_ppo_setup,
+                 lstm_ppo_mujoco_config,
+                 n_procs=cli_args.n_procs,
+                 env_names=cli_args.env_names,
+                 experiment_name=cli_args.experiment_name,
+                 cuda_devices=cli_args.cuda_devices)
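
Note: the third argument to LinearAnnealing appears to be the number of updates over which to anneal: total timesteps divided by timesteps per update (rollout_len * n_envs), e.g. 5_000_000 // (2048 * 16) == 152. A minimal sketch of what such a coefficient presumably computes; the real class lives in modular_baselines and its exact semantics are not shown in this diff:

    class LinearAnnealingSketch:
        # Illustrative stand-in for modular_baselines' LinearAnnealing:
        # interpolate linearly from `init` to `final` over `n_steps` updates.
        def __init__(self, init: float, final: float, n_steps: int):
            self.init, self.final, self.n_steps = init, final, n_steps

        def __call__(self, step: int) -> float:
            frac = min(step / self.n_steps, 1.0)
            return self.init + (self.final - self.init) * frac

    lr = LinearAnnealingSketch(3e-4, 0.0, 5_000_000 // (2048 * 16))  # 152 updates
    assert lr(0) == 3e-4 and lr(152) == 0.0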

examples/mujoco/ppo_torch.py

Lines changed: 19 additions & 5 deletions
@@ -1,5 +1,6 @@
 from typing import Any, Dict
 from argparse import ArgumentParser
+import numpy as np

 from modular_baselines.algorithms.ppo.ppo import PPO, PPOArgs
 from modular_baselines.algorithms.ppo.torch_agent import TorchPPOAgent
@@ -9,8 +10,8 @@
 from torch_setup import MujocoTorchConfig, setup, parallel_run, add_arguments


-def ppo_setup(env_name: str, config: Dict[str, Any], seed: int):
-    return setup(PPO, TorchPPOAgent, SeparateFeatureNetwork, env_name, config, seed)
+def ppo_setup(env_name: str, config: Dict[str, Any], experiment_name: str, device: str):
+    return setup(PPO, TorchPPOAgent, SeparateFeatureNetwork, experiment_name, env_name, config, device)


 ppo_mujoco_config = MujocoTorchConfig(
@@ -26,16 +27,29 @@ def ppo_setup(env_name: str, config: Dict[str, Any], seed: int):
         batch_size=64,
         max_grad_norm=1.0,
         normalize_advantage=True,
+        use_vec_normalization=True,
+        vec_norm_info={
+            "norm_reward": True,
+            "norm_obs": True,
+            "clip_obs": 1e5,
+            "clip_reward": 1e5,
+        },
     ),
     n_envs=16,
+    name="default",
     total_timesteps=5_000_000,
     log_interval=1,
-    device="cpu",
+    record_video=False,
+    seed=np.random.randint(2**10, 2**30),
 )

 if __name__ == "__main__":
     parser = ArgumentParser("PPO Mujoco")
     add_arguments(parser)
     cli_args = parser.parse_args()
-    parallel_run(ppo_setup, ppo_mujoco_config, n_procs=cli_args.n_procs,
-                 env_names=cli_args.env_names, n_seeds=cli_args.n_seeds)
+    parallel_run(ppo_setup,
+                 ppo_mujoco_config,
+                 n_procs=cli_args.n_procs,
+                 env_names=cli_args.env_names,
+                 experiment_name=cli_args.experiment_name,
+                 cuda_devices=cli_args.cuda_devices)
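
Note: with running statistics in hand, VecNormalize-style wrappers standardize observations and scale rewards before clipping. A sketch of the usual formulas; variable names are illustrative, not modular_baselines API:

    import numpy as np

    EPS = 1e-8

    def normalize_obs(obs, rms_mean, rms_var, clip_obs=1e5):
        # Standardize with running stats, then clip to +/- clip_obs.
        return np.clip((obs - rms_mean) / np.sqrt(rms_var + EPS),
                       -clip_obs, clip_obs)

    def normalize_reward(reward, return_rms_var, clip_reward=1e5):
        # Rewards are conventionally scaled by the std of the running
        # discounted return and not mean-shifted.
        return np.clip(reward / np.sqrt(return_rms_var + EPS),
                       -clip_reward, clip_reward)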

model_train.py

Lines changed: 0 additions & 152 deletions
This file was deleted.

modular_baselines/algorithms/a2c/a2c.py

Lines changed: 36 additions & 2 deletions
@@ -24,6 +24,8 @@ class A2CArgs():
     lr: Coefficient
     max_grad_norm: float
     normalize_advantage: bool
+    use_vec_normalization: bool
+    vec_norm_info: Dict[str, Union[float, bool, int, str]]


 class A2C(OnPolicyAlgorithm):
@@ -104,15 +106,31 @@ def setup(env: VecEnv,
         """
         observation_space, action_space, action_dim = A2C._setup(env)

+        normalizer_struct = []
+        if args.use_vec_normalization:
+            normalizer_struct = [
+                ("reward_rms_var", np.float32, (1,)),
+                ("obs_rms_mean", np.float32, observation_space.shape),
+                ("obs_rms_var", np.float32, observation_space.shape),
+                ("next_obs_rms_mean", np.float32, observation_space.shape),
+                ("next_obs_rms_var", np.float32, observation_space.shape),
+            ]
         struct = np.dtype([
             ("observation", np.float32, observation_space.shape),
             ("next_observation", np.float32, observation_space.shape),
             ("action", action_space.dtype, (action_dim,)),
             ("reward", np.float32, (1,)),
             ("termination", np.float32, (1,)),
+            *normalizer_struct
         ])
         buffer = Buffer(struct, args.rollout_len, env.num_envs, data_logger, buffer_callbacks)
-        collector = RolloutCollector(env, buffer, agent, data_logger, collector_callbacks)
+        collector = RolloutCollector(
+            env=env,
+            buffer=buffer,
+            agent=agent,
+            logger=data_logger,
+            store_normalizer_stats=args.use_vec_normalization,
+            callbacks=collector_callbacks)
         return A2C(
             agent=agent,
             collector=collector,
@@ -153,6 +171,15 @@ def setup(env: VecEnv,
         """
         observation_space, action_space, action_dim = A2C._setup(env)

+        normalizer_struct = []
+        if args.use_vec_normalization:
+            normalizer_struct = [
+                ("reward_rms_var", np.float32, (1,)),
+                ("obs_rms_mean", np.float32, observation_space.shape),
+                ("obs_rms_var", np.float32, observation_space.shape),
+                ("next_obs_rms_mean", np.float32, observation_space.shape),
+                ("next_obs_rms_var", np.float32, observation_space.shape),
+            ]
         struct = np.dtype([
             ("observation", np.float32, observation_space.shape),
             ("next_observation", np.float32, observation_space.shape),
@@ -166,7 +193,14 @@
         ])

         buffer = Buffer(struct, args.rollout_len, env.num_envs, data_logger, buffer_callbacks)
-        collector = RecurrentRolloutCollector(env, buffer, agent, data_logger, collector_callbacks)
+        collector = RecurrentRolloutCollector(
+            env=env,
+            buffer=buffer,
+            agent=agent,
+            logger=data_logger,
+            store_normalizer_stats=args.use_vec_normalization,
+            callbacks=collector_callbacks
+        )
         return LstmA2C(
             agent=agent,
             collector=collector,
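
Note on the structured-dtype mechanics: splicing *normalizer_struct into the np.dtype appends the optional fields to the per-transition record, and the Buffer can then address them by name. A standalone numpy sketch; obs_shape and the array dimensions are placeholder values, not taken from this commit:

    import numpy as np

    obs_shape = (17,)  # placeholder observation shape

    normalizer_struct = [
        ("reward_rms_var", np.float32, (1,)),
        ("obs_rms_mean", np.float32, obs_shape),
        ("obs_rms_var", np.float32, obs_shape),
    ]
    struct = np.dtype([
        ("observation", np.float32, obs_shape),
        ("reward", np.float32, (1,)),
        *normalizer_struct,
    ])

    # One record per (step, env) slot, as the Buffer presumably allocates:
    storage = np.zeros((2048, 16), dtype=struct)
    storage["obs_rms_mean"][0, 0] = 0.0  # fields are addressed by name
    assert storage["observation"].shape == (2048, 16, 17)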
