 from train import train_ppo

 sweep_config = {
-    "method": "bayes",
+    "method": "random",
     "metric": {"name": "eval/mean_rewards", "goal": "maximize"},
     "parameters": {
-        "n_envs": {"values": [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]},
-        "n_train_samples": {"values": [int(2**n) for n in range(15, 20)]},
-        "learning_rate": {
-            "distribution": "log_uniform",
-            "min": -10,  # e⁻¹⁰ ~= 5e-5
-            "max": -5,  # e⁻⁵ ~= 6e-3
-        },
-        "n_minibatches": {"values": [8, 16, 32, 64, 128]},
-        "n_epochs": {"values": [5, 10, 15]},
-        "clip_coef": {"distribution": "uniform", "min": 0.1, "max": 0.3},
-        "ent_coef": {"distribution": "uniform", "min": 0.0, "max": 0.25},
-        "vf_coef": {"distribution": "uniform", "min": 0.4, "max": 0.6},
-        "gamma": {"distribution": "uniform", "min": 0.9, "max": 0.999},
-        "gae_lambda": {"distribution": "uniform", "min": 0.5, "max": 0.99},
-        "max_grad_norm": {"distribution": "uniform", "min": 0.2, "max": 5.0},
+        "learning_rate": {"distribution": "uniform", "min": 1e-4, "max": 5e-3},
+        "clip_coef": {"distribution": "uniform", "min": 0.2, "max": 0.3},
+        "ent_coef": {"distribution": "uniform", "min": 0.0, "max": 0.05},
+        "gamma": {"distribution": "uniform", "min": 0.8, "max": 0.99},
+        "gae_lambda": {"distribution": "uniform", "min": 0.9, "max": 0.99},
+        "max_grad_norm": {"distribution": "uniform", "min": 1.0, "max": 5.0},
     },
 }

 config = ConfigDict(
     {
-        "n_envs": 32,
+        "n_envs": 1024,
         "device": "cuda",
-        "total_timesteps": 2_000_000,
-        "learning_rate": 3e-4,
-        "n_steps": 2048,  # Number of steps per environment per policy rollout
-        "gamma": 0.99,  # Discount factor
+        "total_timesteps": 1_000_000,
+        "learning_rate": 1.5e-3,
+        "n_steps": 16,  # Number of steps per environment per policy rollout
+        "gamma": 0.90,  # Discount factor
         "gae_lambda": 0.95,  # Lambda for general advantage estimation
-        "n_minibatches": 32,  # Number of mini-batches
-        "n_epochs": 10,
+        "n_minibatches": 16,  # Number of mini-batches
+        "n_epochs": 15,
         "norm_adv": True,
-        "clip_coef": 0.2,
+        "clip_coef": 0.25,
         "clip_vloss": True,
-        "ent_coef": 0.0,
+        "ent_coef": 0.01,
         "vf_coef": 0.5,
-        "max_grad_norm": 0.5,
+        "max_grad_norm": 5.0,
         "target_kl": None,
         "seed": 0,
         "n_eval_envs": 64,
         "n_eval_steps": 1_000,
         "save_model": False,
-        "eval_interval": 40_000,
+        "eval_interval": 999_000,
     }
 )


-def main(n_runs: int | None = None):
+def main(n_runs: int | None = None, sweep: str | None = None):
     with open("wandb_api_key.secret", "r") as f:
         wandb_api_key = f.read().lstrip("\n").rstrip("\n")
     wandb.login(key=wandb_api_key)
+    project = "crazyflow-ppo-x"

-    sweep_id = wandb.sweep(sweep_config, project="crazyflow-ppo")
+    if sweep is None:
+        sweep = wandb.sweep(sweep_config, project=project)

     wandb.agent(
-        sweep_id,
+        sweep,
         lambda: train_ppo(config.copy_and_resolve_references(), True),
         count=n_runs,
-        project="crazyflow-ppo",
+        project=project,
     )
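The new `sweep` parameter lets additional agents join a sweep that already exists instead of always creating a fresh one. A minimal usage sketch (hypothetical, not part of the diff; the sweep ID and run count are placeholders):

    main(n_runs=20)                    # create a new sweep via wandb.sweep() and run 20 trials
    main(n_runs=20, sweep="abc123de")  # attach another agent to an existing sweep by its ID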