-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain.py
64 lines (48 loc) · 1.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import importlib
import argparse
import sys
import click
from copy import deepcopy
from rl_sumo.helpers.register_environment import make_create_env
from rl_sumo.helpers.preprocessing import execute_preprocessing_tasks
from rl_sumo.helpers.preprocessing import get_parameters
from trainers import TRAINING_FUNCTIONS
def preprocessing(sim_params, *args, **kwargs):
    """Run the user-configured preprocessing tasks for a simulation.

    Tasks are declared in the configuration file under
    ``Simulation.pre_processing_tasks``::

        {
            ...
            "Simulation": {
                "pre_processing_tasks": [
                    {"python_path": "tools.preprocessing.my_custom_function",
                     "module_path": "<absolute path to module root>"
                    }, ...
                ],
                "file_root": ...
            }
        }

    For each task, ``module_path`` is prepended to ``sys.path`` so the task's
    module can be imported. Then ``[python_path, (sim_params,)]`` is handed
    to ``execute_preprocessing_tasks``, so the entire simulation object is
    passed to the task (presumably as its only positional argument; confirm
    against ``execute_preprocessing_tasks``).

    Args:
        sim_params: Simulation parameters (as returned by ``get_parameters``);
            must support ``sim_params["pre_processing_tasks"]``.
        *args: Accepted for call-site compatibility; unused.
        **kwargs: Accepted for call-site compatibility; unused.
    """
    tasks = sim_params["pre_processing_tasks"]
    # Nothing configured (None or an empty list): nothing to do.
    if not tasks:
        return
    for task in tasks:
        # Make the task's package importable before its dotted path is resolved.
        sys.path.insert(0, task["module_path"])
        # The task is identified by its dotted python_path; module_path is only
        # the directory added to sys.path above.
        execute_preprocessing_tasks([[task["python_path"], (sim_params,)]])
@click.option(
    "--config_path",
    help="Path to the JSON configuration file",
)
def _main(config_path):
    """
    Run the RL training described by a JSON configuration file.

    Before training, this builds the environment and simulation parameter
    objects from the configuration file and runs any configured
    preprocessing tasks. The actual RL training functions should live in
    ./trainers/training_functions.py.
    """
    # Build the environment and simulation parameters from the config file.
    env_params, sim_params = get_parameters(config_path)

    # Run user-defined preprocessing tasks before training starts.
    preprocessing(sim_params)

    # Algorithm names are matched case-insensitively against the registry.
    algorithm = env_params.algorithm.lower()
    try:
        train = TRAINING_FUNCTIONS[algorithm]
    except KeyError:
        # Replace a bare KeyError traceback with a readable CLI error that
        # lists the algorithms that are actually available.
        raise click.ClickException(
            "Unknown algorithm {!r}; expected one of: {}".format(
                env_params.algorithm,
                ", ".join(sorted(map(str, TRAINING_FUNCTIONS))),
            )
        ) from None
    train(sim_params, env_params)
# Build the click command from `_main` explicitly instead of decorating it
# with @click.command(); this is to bypass the pylint errors (presumably
# no-value-for-parameter on the argument-less `main()` call below).
main = click.command()(_main)
if __name__ == "__main__":
    # Script entry point: with no arguments, click parses options from sys.argv.
    main()