# ppo_4x4grid.py (executable file)
import os
import sys
# Abort early when SUMO_HOME is not configured; otherwise make the modules
# under $SUMO_HOME/tools importable by appending that directory to sys.path.
if "SUMO_HOME" not in os.environ:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")
tools = os.path.join(os.environ["SUMO_HOME"], "tools")
sys.path.append(tools)
import numpy as np
import pandas as pd
import ray
import traci
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.registry import register_env
import sumo_rl
if __name__ == "__main__":
    # Script entry point: register the 4x4 grid environment with RLlib,
    # build a PPO configuration and hand it to Ray Tune.
    print(os.getcwd())  # net/route paths below are relative to this directory
    ray.init()

    env_name = "4x4grid"

    def make_env(_):
        """Create the sumo_rl parallel environment wrapped for RLlib.

        The single positional argument supplied by the registry is ignored.
        """
        sumo_env = sumo_rl.parallel_env(
            net_file="sumo_rl/nets/4x4-Lucas/4x4.net.xml",
            route_file="sumo_rl/nets/4x4-Lucas/4x4c1c2c1c2.rou.xml",
            out_csv_name="outputs/4x4grid/ppo",
            use_gui=False,
            num_seconds=80000,
        )
        return ParallelPettingZooEnv(sumo_env)

    register_env(env_name, make_env)

    # PPO hyper-parameters, forwarded unchanged to PPOConfig.training().
    ppo_hyperparams = {
        "train_batch_size": 512,
        "lr": 2e-5,
        "gamma": 0.95,
        "lambda_": 0.9,
        "use_gae": True,
        "clip_param": 0.4,
        "grad_clip": None,
        "entropy_coeff": 0.1,
        "vf_loss_coeff": 0.25,
        "sgd_minibatch_size": 64,
        "num_sgd_iter": 10,
    }

    # Build the configuration one section at a time; each builder call
    # returns the config object, so this mirrors a single fluent chain.
    config = PPOConfig()
    config = config.environment(env=env_name, disable_env_checking=True)
    config = config.rollouts(num_rollout_workers=4, rollout_fragment_length=128)
    config = config.training(**ppo_hyperparams)
    config = config.debugging(log_level="ERROR")
    config = config.framework(framework="torch")
    # GPU count is taken from RLLIB_NUM_GPUS and defaults to 0 when unset.
    gpu_count = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
    config = config.resources(num_gpus=gpu_count)

    stop_criteria = {"timesteps_total": 100000}
    results_dir = "~/ray_results/" + env_name

    tune.run(
        "PPO",
        name="PPO",
        stop=stop_criteria,
        checkpoint_freq=10,
        local_dir=results_dir,
        config=config.to_dict(),
    )