1
- from typing import List , Any , Dict , Union , Optional , Tuple , Callable , Type
1
+ from typing import List , Any , Dict , Union , Optional , Tuple , Callable , Type , Iterable
2
2
import torch
3
3
import os
4
4
import numpy as np
5
5
import sys
6
6
from multiprocessing import Process , Queue
7
7
from dataclasses import dataclass
8
8
import argparse
9
+ import time
10
+ import gym
11
+ from datetime import datetime
9
12
10
13
from stable_baselines3 .common .env_util import make_vec_env
14
+ from stable_baselines3 .common .vec_env .base_vec_env import VecEnv
11
15
from stable_baselines3 .common .vec_env .subproc_vec_env import SubprocVecEnv
12
16
from stable_baselines3 .common .vec_env .vec_normalize import VecNormalize
13
17
from stable_baselines3 .common .logger import HumanOutputFormat , CSVOutputFormat , JSONOutputFormat
18
+ from stable_baselines3 .common .vec_env .vec_video_recorder import VecVideoRecorder
19
+ from stable_baselines3 .common .running_mean_std import RunningMeanStd
14
20
15
21
from modular_baselines .algorithms .algorithm import BaseAlgorithm
16
22
from modular_baselines .algorithms .agent import BaseAgent
17
- from modular_baselines .loggers .writers import ScalarWriter , DictWriter
23
+ from modular_baselines .loggers .writers import ScalarWriter , DictWriter , BaseWriter , SaveModelParametersWriter , LogConfigs
18
24
from modular_baselines .loggers .data_logger import DataLogger
19
25
20
26
21
27
@dataclass (frozen = True )
22
28
class MujocoTorchConfig ():
23
29
args : Any
30
+ name : str
24
31
n_envs : int
25
32
total_timesteps : int
26
33
log_interval : int
27
- device : str
28
-
29
-
30
- def setup (algorithm_cls : Type [BaseAlgorithm ],
31
- agent_cls : Type [BaseAgent ],
32
- network : Type [torch .nn .Module ],
33
- env_name : str ,
34
- config : MujocoTorchConfig ,
35
- seed : int
36
- ) -> BaseAlgorithm :
37
- np .random .seed (seed )
38
- torch .manual_seed (seed )
39
-
40
- log_dir = f"logs/{ algorithm_cls .__name__ } -{ env_name .lower ()} /{ seed } "
34
+ record_video : bool
35
+ seed : int
36
+
37
+
38
+ def pre_setup (experiment_name : str ,
39
+ env : Union [gym .Env , str ],
40
+ config : MujocoTorchConfig ,
41
+ ) -> Tuple [DataLogger , List [BaseWriter ], VecEnv ]:
42
+ """ Prepare loggers and vectorized environment
43
+
44
+ Args:
45
+ experiment_name (str): Name of the experiment
46
+ env (Union[gym.Env, str]): Name of the environment or the environment
47
+ config (MujocoTorchConfig): Torch Mujoco configuration
48
+
49
+ Returns:
50
+ Tuple[DataLogger, List[BaseWriter], VecEnv]: Data logger, writers list and vectorized
51
+ environment
52
+ """
53
+ np .random .seed (config .seed )
54
+ torch .manual_seed (config .seed )
55
+ env_name = env if isinstance (env , str ) else env .__class__ .__name__
56
+ date_time = datetime .now ().strftime ("%Y-%m-%dT%H:%M:%S" )
57
+
58
+ log_dir = f"logs/{ experiment_name } -{ env_name .lower ()} /{ config .name } /{ date_time } "
41
59
data_logger = DataLogger ()
42
60
os .makedirs (log_dir , exist_ok = True )
43
61
sb3_writers = [HumanOutputFormat (sys .stdout ),
44
62
CSVOutputFormat (os .path .join (log_dir , "progress.csv" )),
45
63
JSONOutputFormat (os .path .join (log_dir , "progress.json" ))]
46
64
logger_callbacks = [
47
65
ScalarWriter (interval = config .log_interval , dir_path = log_dir , writers = sb3_writers ),
48
- DictWriter (interval = config .log_interval , dir_path = log_dir )
66
+ DictWriter (interval = config .log_interval , dir_path = log_dir ),
67
+ SaveModelParametersWriter (interval = config .log_interval * 1 , dir_path = log_dir )
49
68
]
50
69
51
70
vecenv = make_vec_env (
52
- env_name ,
71
+ env ,
53
72
n_envs = config .n_envs ,
54
- seed = seed ,
73
+ seed = config . seed ,
55
74
wrapper_class = None ,
56
75
vec_env_cls = SubprocVecEnv )
57
- vecenv = VecNormalize (vecenv , training = True , gamma = config .args .gamma )
76
+ if config .args .use_vec_normalization :
77
+ vecenv = VecNormalize (
78
+ vecenv ,
79
+ training = True ,
80
+ gamma = config .args .gamma ,
81
+ ** config .args .vec_norm_info )
82
+ if config .args .vec_norm_info ["norm_obs" ] is False :
83
+ vecenv .obs_rms = RunningMeanStd (shape = vecenv .observation_space .shape )
84
+ if config .record_video :
85
+ vecenv = VecVideoRecorder (
86
+ vecenv ,
87
+ f"{ log_dir } /videos" ,
88
+ record_video_trigger = lambda x : x % 25000 == 0 , video_length = 1000
89
+ )
90
+ LogConfigs (config = config , dir_path = log_dir )
91
+
92
+ return data_logger , logger_callbacks , vecenv
93
+
94
+
95
+ def setup (algorithm_cls : Type [BaseAlgorithm ],
96
+ agent_cls : Type [BaseAgent ],
97
+ network : Type [torch .nn .Module ],
98
+ experiment_name : str ,
99
+ env_name : str ,
100
+ config : MujocoTorchConfig ,
101
+ device : str
102
+ ) -> BaseAlgorithm :
103
+
104
+ experiment_name = "-" .join ([experiment_name , algorithm_cls .__name__ ])
105
+ data_logger , logger_callbacks , vecenv = pre_setup (experiment_name , env_name , config )
58
106
59
107
policy = network (observation_space = vecenv .observation_space ,
60
108
action_space = vecenv .action_space )
61
- policy .to (config . device )
109
+ policy .to (device )
62
110
optimizer = torch .optim .Adam (policy .parameters (), eps = 1e-5 )
63
111
agent = agent_cls (policy ,
64
112
optimizer ,
@@ -80,39 +128,48 @@ def setup(algorithm_cls: Type[BaseAlgorithm],
80
128
81
129
82
130
def add_arguments (parser : argparse .ArgumentParser ) -> None :
131
+ parser .add_argument ("--experiment-name" , type = str , default = "" ,
132
+ help = "Prefix of the experiment name" )
83
133
parser .add_argument ("--n-procs" , type = int , default = 1 ,
84
134
help = "Number of parallelized processes for experiments" )
85
- parser .add_argument ("--n-seeds" , type = int , default = 1 ,
86
- help = "Number of seeds/runs per environment" )
87
135
parser .add_argument ("--env-names" , nargs = '+' , type = str , required = True ,
88
136
help = "Gym environment names" )
137
+ parser .add_argument ("--cuda-devices" , nargs = '+' , type = int , required = False ,
138
+ help = "Available cuda devices" )
89
139
90
140
91
- def worker (setup_fn , argument_queue : Queue , rank : int ) -> None :
141
+ def worker (setup_fn , argument_queue : Queue , rank : int , cuda_devices ) -> None :
142
+ device = "cpu" if cuda_devices is None else f"cuda:{ cuda_devices [rank % len (cuda_devices )]} "
143
+ print (f"Worker-{ rank } use device: { device } " )
92
144
while not argument_queue .empty ():
93
145
kwargs = argument_queue .get ()
94
- setup_fn (** kwargs )
146
+ setup_fn (device = device , ** kwargs )
95
147
96
148
97
149
def parallel_run (setup_fn : Callable [[str , MujocoTorchConfig , int ], BaseAlgorithm ],
98
- config : MujocoTorchConfig ,
150
+ configs : Union [MujocoTorchConfig , Iterable [MujocoTorchConfig ]],
151
+ experiment_name : str ,
99
152
n_procs : int ,
100
153
env_names : Tuple [str ],
101
- n_seeds : int
154
+ cuda_devices : Tuple [ int ],
102
155
) -> None :
103
156
104
- arguments = [dict (env_name = env_name , seed = seed , config = config )
157
+ if not isinstance (configs , Iterable ):
158
+ configs = [configs ]
159
+
160
+ arguments = [dict (env_name = env_name , config = config , experiment_name = experiment_name )
105
161
for env_name in env_names
106
- for seed in np . random . randint ( 2 ** 10 , 2 ** 30 , size = n_seeds ). tolist () ]
162
+ for config in configs ]
107
163
108
164
argument_queue = Queue ()
109
165
for arg in arguments :
110
166
argument_queue .put (arg )
111
167
112
- processes = [Process (target = worker , args = (setup_fn , argument_queue , rank ))
168
+ processes = [Process (target = worker , args = (setup_fn , argument_queue , rank , cuda_devices ))
113
169
for rank in range (n_procs )]
114
170
115
171
for proc in processes :
172
+ time .sleep (1.5 ) # To avoid having the same log name
116
173
proc .start ()
117
174
118
175
for proc in processes :
0 commit comments