
Commit f72cc2b

Update hyperparams
1 parent d0852f1 commit f72cc2b


2 files changed: +35 −42 lines changed

tutorials/ppo/sweep.py (+24 −31)

```diff
@@ -4,66 +4,59 @@
 from train import train_ppo

 sweep_config = {
-    "method": "bayes",
+    "method": "random",
     "metric": {"name": "eval/mean_rewards", "goal": "maximize"},
     "parameters": {
-        "n_envs": {"values": [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]},
-        "n_train_samples": {"values": [int(2**n) for n in range(15, 20)]},
-        "learning_rate": {
-            "distribution": "log_uniform",
-            "min": -10, # e⁻¹⁰ ~= 5e-5
-            "max": -5, # e⁻⁵ ~= 6e-3
-        },
-        "n_minibatches": {"values": [8, 16, 32, 64, 128]},
-        "n_epochs": {"values": [5, 10, 15]},
-        "clip_coef": {"distribution": "uniform", "min": 0.1, "max": 0.3},
-        "ent_coef": {"distribution": "uniform", "min": 0.0, "max": 0.25},
-        "vf_coef": {"distribution": "uniform", "min": 0.4, "max": 0.6},
-        "gamma": {"distribution": "uniform", "min": 0.9, "max": 0.999},
-        "gae_lambda": {"distribution": "uniform", "min": 0.5, "max": 0.99},
-        "max_grad_norm": {"distribution": "uniform", "min": 0.2, "max": 5.0},
+        "learning_rate": {"distribution": "uniform", "min": 1e-4, "max": 5e-3},
+        "clip_coef": {"distribution": "uniform", "min": 0.2, "max": 0.3},
+        "ent_coef": {"distribution": "uniform", "min": 0.0, "max": 0.05},
+        "gamma": {"distribution": "uniform", "min": 0.8, "max": 0.99},
+        "gae_lambda": {"distribution": "uniform", "min": 0.9, "max": 0.99},
+        "max_grad_norm": {"distribution": "uniform", "min": 1.0, "max": 5.0},
     },
 }

 config = ConfigDict(
     {
-        "n_envs": 32,
+        "n_envs": 1024,
         "device": "cuda",
-        "total_timesteps": 2_000_000,
-        "learning_rate": 3e-4,
-        "n_steps": 2048, # Number of steps per environment per policy rollout
-        "gamma": 0.99, # Discount factor
+        "total_timesteps": 1_000_000,
+        "learning_rate": 1.5e-3,
+        "n_steps": 16, # Number of steps per environment per policy rollout
+        "gamma": 0.90, # Discount factor
         "gae_lambda": 0.95, # Lambda for general advantage estimation
-        "n_minibatches": 32, # Number of mini-batches
-        "n_epochs": 10,
+        "n_minibatches": 16, # Number of mini-batches
+        "n_epochs": 15,
         "norm_adv": True,
-        "clip_coef": 0.2,
+        "clip_coef": 0.25,
         "clip_vloss": True,
-        "ent_coef": 0.0,
+        "ent_coef": 0.01,
         "vf_coef": 0.5,
-        "max_grad_norm": 0.5,
+        "max_grad_norm": 5.0,
         "target_kl": None,
         "seed": 0,
         "n_eval_envs": 64,
         "n_eval_steps": 1_000,
         "save_model": False,
-        "eval_interval": 40_000,
+        "eval_interval": 999_000,
     }
 )


-def main(n_runs: int | None = None):
+def main(n_runs: int | None = None, sweep: str | None = None):
     with open("wandb_api_key.secret", "r") as f:
         wandb_api_key = f.read().lstrip("\n").rstrip("\n")
     wandb.login(key=wandb_api_key)
+    project = "crazyflow-ppo-x"

-    sweep_id = wandb.sweep(sweep_config, project="crazyflow-ppo")
+    if sweep is None:
+        sweep = wandb.sweep(sweep_config, project=project)

     wandb.agent(
-        sweep_id,
+        sweep,
         lambda: train_ppo(config.copy_and_resolve_references(), True),
         count=n_runs,
-        project="crazyflow-ppo",
+        project=project,
     )

```
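Two things stand out in the sweep.py changes. First, the search strategy moves from Bayesian optimization over a wide space to random search over a much narrower one; the removed `log_uniform` learning-rate bounds were given in natural-log space (roughly 5e-5 to 6e-3, per the removed comments), whereas the new `uniform` range samples learning rates linearly between 1e-4 and 5e-3. Second, `main` now takes an optional sweep ID, so additional agents can attach to an already-created sweep instead of always starting a new one. A minimal sketch of how this could be driven from the command line (the `argparse` wrapper below is illustrative only and not part of this commit):

```python
# Hypothetical launcher for sweep.py (not part of this commit): creates a new
# W&B sweep when no ID is given, otherwise attaches another agent to an
# existing sweep so several machines can contribute runs to the same search.
import argparse

from sweep import main  # assumes tutorials/ppo/sweep.py is on the import path

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch a PPO hyperparameter sweep agent.")
    parser.add_argument("--n-runs", type=int, default=None, help="Runs for this agent (None = no limit).")
    parser.add_argument("--sweep", type=str, default=None, help="Existing W&B sweep ID to attach to.")
    args = parser.parse_args()
    main(n_runs=args.n_runs, sweep=args.sweep)
```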
tutorials/ppo/train.py (+11 −11)

```diff
@@ -160,7 +160,7 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
             wandb_api_key = f.read().lstrip("\n").rstrip("\n")

         wandb.login(key=wandb_api_key)
-        wandb.init(project="crazyflow-ppo", config=None)
+        wandb.init(project="crazyflow-ppo-x", config=None)
         config.update(wandb.config)
     if config.get("n_train_samples"):
         config.n_steps = config.n_train_samples // config.n_envs
@@ -336,7 +336,7 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

         # Evaluate the agent
-        if global_step - last_eval > config.eval_interval:
+        if global_step - last_eval >= config.eval_interval:
             sync_envs(train_envs, eval_envs)
             eval_rewards, eval_steps = evaluate_agent(
                 eval_envs,
@@ -375,33 +375,33 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
     if config.save_model:
         save_model(agent, optimizer, train_envs, Path(__file__).parent / "ppo_checkpoint.pt")

-    plot_results(train_rewards_hist, train_rewards_steps, eval_rewards_hist, eval_rewards_steps)
+    # plot_results(train_rewards_hist, train_rewards_steps, eval_rewards_hist, eval_rewards_steps)


 if __name__ == "__main__":
     config = ConfigDict(
         {
-            "n_envs": 32,
+            "n_envs": 1024,
             "device": "cuda",
-            "total_timesteps": 2_000_000,
-            "learning_rate": 5e-3,
-            "n_steps": 1024, # Number of steps per environment per policy rollout
+            "total_timesteps": 1_000_000,
+            "learning_rate": 1.5e-3,
+            "n_steps": 16, # Number of steps per environment per policy rollout
             "gamma": 0.90, # Discount factor
-            "gae_lambda": 0.90, # Lambda for general advantage estimation
-            "n_minibatches": 8, # Number of mini-batches
+            "gae_lambda": 0.95, # Lambda for general advantage estimation
+            "n_minibatches": 16, # Number of mini-batches
             "n_epochs": 15,
             "norm_adv": True,
             "clip_coef": 0.25,
             "clip_vloss": True,
-            "ent_coef": 0.0,
+            "ent_coef": 0.01,
             "vf_coef": 0.5,
             "max_grad_norm": 5.0,
             "target_kl": None,
             "seed": 0,
             "n_eval_envs": 64,
             "n_eval_steps": 1_000,
             "save_model": False,
-            "eval_interval": 40_000,
+            "eval_interval": 999_000,
         }
     )

```
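The new defaults trade long rollouts from few environments for short rollouts from many. A quick back-of-the-envelope check of what they imply, assuming the usual PPO bookkeeping where the global step counter advances by `n_envs * n_steps` per rollout (the training loop itself is not shown in this diff):

```python
# Sanity check of the updated hyperparameters; assumes global_step grows by
# n_envs * n_steps per rollout, as in typical PPO implementations.
n_envs, n_steps, n_minibatches = 1024, 16, 16
total_timesteps, eval_interval = 1_000_000, 999_000

batch_size = n_envs * n_steps                 # 16_384 transitions per rollout
minibatch_size = batch_size // n_minibatches  # 1_024 transitions per gradient step
n_rollouts = total_timesteps // batch_size    # ~61 policy updates in total

# The eval trigger is now inclusive (>=); with eval_interval=999_000 and ~61
# rollouts of 16_384 steps each, evaluation fires once, near the end of training.
print(batch_size, minibatch_size, n_rollouts)
```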