-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
77 lines (59 loc) · 2.39 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tqdm
import argparse
import numpy as np
from datetime import datetime
from gym.spaces import Box, Discrete
import tensorflow.compat.v1 as tf_v1
from train import get_env
def get_args(argv=None):
    """Parse the command-line options for this test script.

    Args:
        argv: Optional list of argument strings to parse. When None,
            argparse falls back to ``sys.argv[1:]``, so existing callers
            that pass nothing behave exactly as before.

    Returns:
        dict mapping each option (env_name, load_model, dump_path, epochs,
        include) to its parsed value; the keys match ``main``'s parameters
        so the result can be splatted with ``main(**get_args())``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, help="Environment name", required=True)
    parser.add_argument("--load_model", type=str, help=".tf model path", required=True)
    parser.add_argument("--dump_path", type=str, help="Path to dump recorded videos", default=None)
    parser.add_argument("--epochs", type=int, help="Number of test epochs", default=10)
    parser.add_argument("--include", help="Additional modules to import", nargs="*")
    # vars() is the public way to get the Namespace's attribute dict.
    return vars(parser.parse_args(argv))
def discrete_policy_wrapper(policy):
    """Adapt a batched ``predict`` model into a per-state discrete policy.

    The returned callable takes a single ``state``, adds a leading batch
    axis, runs ``policy.predict`` on that one-item batch, and returns the
    index of the largest output score (``np.argmax``) as the action.
    Any keyword arguments passed to it are accepted and ignored.
    """
    def wrapper(state, **kwargs):
        single_item_batch = np.expand_dims(state, axis=0)
        scores = policy.predict(single_item_batch)[0]  # drop the batch axis
        return np.argmax(scores)

    return wrapper
def continuous_policy_wrapper(policy):
    """Adapt a batched ``predict`` model into a per-state continuous policy.

    The returned callable wraps ``state`` in a one-item batch, calls
    ``policy.predict`` on it and hands back the first (only) row of the
    prediction unchanged as the action. Keyword arguments are ignored.
    """
    def wrapper(state, **kwargs):
        return policy.predict(np.expand_dims(state, axis=0))[0]

    return wrapper
def wrap_policy(env, policy):
    """Pick the policy adapter that matches ``env``'s action space.

    ``Discrete`` spaces get ``discrete_policy_wrapper`` (argmax over the
    model's outputs); ``Box`` spaces get ``continuous_policy_wrapper`` (raw
    model output).

    Args:
        env: Environment whose ``action_space`` type selects the adapter.
        policy: Model object exposing ``predict`` (``setup`` passes the
            loaded Keras model).

    Returns:
        A callable ``wrapper(state, **kwargs)`` that returns one action.

    Raises:
        NotImplementedError: if the action space is neither ``Discrete``
            nor ``Box``.
    """
    action_space = env.action_space
    if isinstance(action_space, Discrete):
        return discrete_policy_wrapper(policy)
    if isinstance(action_space, Box):
        return continuous_policy_wrapper(policy)
    # Name the rejected space so the failure is diagnosable from the traceback.
    raise NotImplementedError(
        f"Unsupported action space type: {type(action_space).__name__} ({action_space!r})"
    )
def setup(env_name, load_model, dump_path, include):
    """Create the environment and a ready-to-call policy for it.

    Args:
        env_name: Environment name forwarded to ``get_env``.
        load_model: Path of the saved model for ``tf_v1.keras.models.load_model``.
        dump_path: Forwarded to ``get_env`` as ``dump_dir`` (where recorded
            videos go, per the CLI help).
        include: Extra modules forwarded to ``get_env``.

    Returns:
        ``(env, policy)``: the env from ``get_env`` (called with
        ``record_interval=1`` — presumably record every episode, confirm in
        ``train.get_env``) and the loaded model adapted by ``wrap_policy``.
    """
    environment = get_env(
        env_name,
        record_interval=1,
        dump_dir=dump_path,
        include=include,
    )
    model = tf_v1.keras.models.load_model(load_model)
    return environment, wrap_policy(environment, model)
def rollout(env, policy, epochs=1):
    """Play ``epochs`` full episodes of ``policy`` in ``env``.

    Each episode starts from ``env.reset()`` and steps with
    ``policy(observation)`` until ``env.step`` reports done. The step result
    is unpacked as a 4-tuple ``(obs, reward, done, info)``; rewards and info
    are discarded and nothing is returned. A tqdm progress bar tracks episodes.
    """
    for _ in tqdm.trange(1, epochs + 1, desc="Testing"):
        observation = env.reset()
        finished = False
        while not finished:
            observation, _, finished, _ = env.step(policy(observation))
def main(env_name, load_model, dump_path=None, epochs=5, include=None):
    """Load a saved policy, roll it out in ``env_name`` and record episodes.

    Args:
        env_name: Environment name forwarded to ``setup``/``get_env``.
        load_model: Path of the saved model to load.
        dump_path: Directory for recorded videos. When None, a timestamped
            directory ``test_videos/<env_name>-<dd.mm.YYYY_HH.MM>`` is used.
        epochs: Number of episodes to run. Note the CLI (``get_args``)
            defaults this to 10; the default of 5 only applies to direct calls.
        include: Additional modules to import, forwarded to ``get_env``.
    """
    if dump_path is None:
        date_time = datetime.now().strftime("%d.%m.%Y_%H.%M")
        dump_path = os.path.join("test_videos", f"{env_name}-{date_time}")
    env, policy = setup(env_name, load_model, dump_path, include)
    try:
        rollout(env, policy, epochs=epochs)
    finally:
        # Always release the environment, even if a rollout raises. Recording
        # wrappers typically finalize the last video on close() — TODO confirm
        # for the env returned by train.get_env.
        env.close()
if __name__ == "__main__":
    # CLI entry point: parsed option names map one-to-one onto main()'s parameters.
    main(**get_args())