-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
99 lines (79 loc) · 3.05 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import os
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch as th
from utils.logging import get_logger
import yaml
from run import run
# Expose only the GPU with index 1 to CUDA for this process. The assignment is
# unconditional, so it replaces any CUDA_VISIBLE_DEVICES value already present
# in the environment. NOTE(review): hard-coded device index — confirm intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
# Module-level logger; also attached to the sacred experiment below.
logger = get_logger()
# The sacred experiment named "pymarl"; `my_main` is registered on it via
# the `@ex.main` decorator.
ex = Experiment("pymarl")
ex.logger = logger
# Post-process captured output with sacred's apply_backspaces_and_linefeeds.
ex.captured_out_filter = apply_backspaces_and_linefeeds
# Results root: "A_DATA/QMIX_Share" under the parent of the directory that
# contains this file.
results_path = os.path.join(dirname(dirname(abspath(__file__))), "A_DATA/QMIX_Share")
@ex.main
def my_main(_run, _config, _log):
    """Seed the random number generators and start the framework.

    Args:
        _run: Forwarded unchanged to ``run``.
        _config: Experiment configuration. It is copied with ``config_copy``
            first, so the object passed in is not modified here.
        _log: Forwarded unchanged to ``run``.
    """
    run_config = config_copy(_config)

    # One seed drives NumPy, torch and the environment arguments.
    seed = run_config["seed"]
    np.random.seed(seed)
    th.manual_seed(seed)
    run_config["env_args"]["seed"] = seed

    # Hand over to the framework.
    run(_run, run_config, _log)
def _get_config(params, arg_name, subfolder):
config_name = None
for _i, _v in enumerate(params):
if _v.split("=")[0] == arg_name:
config_name = _v.split("=")[1]
del params[_i]
break
if config_name is not None:
with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)),
"r") as f:
try:
config_dict = yaml.safe_load(f)
except yaml.YAMLError as exc:
assert False, "{}.yaml error: {}".format(config_name, exc)
return config_dict
def recursive_dict_update(d, u):
    """Merge mapping ``u`` into ``d`` in place, recursing into nested mappings.

    For every key in ``u``: if the value is a mapping, it is merged into the
    mapping already stored under that key in ``d``; otherwise the value simply
    overwrites ``d[key]``. When ``d`` has no mapping under the key (the key is
    missing, or holds a non-mapping such as ``None``), a fresh dict is used as
    the merge target so the nested mapping from ``u`` replaces it.

    Args:
        d: Dictionary to update; modified in place.
        u: Mapping whose entries take precedence over those in ``d``.

    Returns:
        ``d`` itself, after the update.
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so the block is self-contained.
    from collections.abc import Mapping

    for key, value in u.items():
        if isinstance(value, Mapping):
            target = d.get(key)
            if not isinstance(target, Mapping):
                # Nothing mergeable there yet: start from an empty dict.
                target = {}
            d[key] = recursive_dict_update(target, value)
        else:
            d[key] = value
    return d
def config_copy(config):
    """Return a recursive copy of ``config`` built from plain containers.

    ``dict`` instances (including subclasses) come back as plain ``dict``
    objects and ``list`` instances as plain ``list`` objects, with every
    element copied recursively. Any other value is duplicated with
    ``copy.deepcopy``.

    Args:
        config: The configuration value to copy.

    Returns:
        An independent copy of ``config``.
    """
    if isinstance(config, dict):
        copied = {}
        for key, value in config.items():
            copied[key] = config_copy(value)
        return copied

    if isinstance(config, list):
        return list(map(config_copy, config))

    # Leaf value (or a container type not handled above).
    return deepcopy(config)
if __name__ == '__main__':
    # Work on a copy: _get_config removes the flags it consumes from the list.
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            config_dict = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            # Raised explicitly (same exception type as before) so the check
            # still fires when Python runs with -O.
            raise AssertionError("default.yaml error: {}".format(exc)) from exc
    if config_dict is None:
        # safe_load returns None for an empty document.
        config_dict = {}

    # Load algorithm and env base configs
    env_config = _get_config(params, "--env-config", "envs")
    alg_config = _get_config(params, "--config", "algs")

    # Merge in precedence order: defaults < env config < algorithm config.
    # A config is None when its flag was not given or its YAML file is empty;
    # skip it instead of failing on None.items().
    for override in (env_config, alg_config):
        if override is not None:
            config_dict = recursive_dict_update(config_dict, override)

    # now add all the config to sacred
    ex.add_config(config_dict)

    # Save to disk by default for sacred
    file_obs_path = os.path.join(results_path, "sacred")
    # Log the directory actually used rather than a fixed "results/sacred".
    logger.info("Saving to FileStorageObserver in {}.".format(file_obs_path))
    ex.observers.append(FileStorageObserver.create(file_obs_path))

    ex.run_commandline(params)