-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate_a3c.py
53 lines (35 loc) · 1.57 KB
/
evaluate_a3c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import gym
import torch
import environment
from models import A3C
from utils.helper import evaluate_A3C_lstm
def get_args(argv=None):
    """Parse command-line options for A3C agent evaluation.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to sys.argv[1:] — identical to the old
            behavior, but a list can be passed for testing.

    Returns:
        argparse.Namespace with attributes: environment (str), checkpoint
        (str, required), video (str), lstm (bool), device (torch.device).
    """
    ap = argparse.ArgumentParser()
    # Typo fixed in help text ("envirement" -> "environment").
    ap.add_argument("-e", "--environment", default="BreakoutNoFrameskip-v4",
                    help="environment to play")
    ap.add_argument("-c", "--checkpoint", required=True,
                    help="checkpoint for agent")
    ap.add_argument("-v", "--video", default="videos",
                    help="directory to save evaluation videos")
    ap.add_argument("--lstm", action='store_true', help="Enable LSTM")
    # Default is a torch.device object, picked once at parse time.
    ap.add_argument("--device",
                    default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                    help="Device for training")
    opt = ap.parse_args(argv)
    return opt
def main():
    """Load an A3C checkpoint and record one evaluation run to video.

    Builds the requested environment, restores the agent from the checkpoint
    given on the command line, runs the LSTM evaluation helper inside a
    gym Monitor wrapper (which writes video files to opt.video), and prints
    the resulting reward.
    """
    opt = get_args()
    # Explicit validation instead of `assert`, which is silently stripped
    # when Python runs with -O. Also fixes the "environemts" typo.
    if opt.environment not in environment.ENV_DICT:
        raise SystemExit(
            "Unsupported environment: {} \nSupported environments: {}".format(
                opt.environment, environment.ENV_DICT.keys()))

    device = opt.device
    evaluate = evaluate_A3C_lstm
    ENV = environment.ENV_DICT[opt.environment]

    # Throwaway env used only to query the action space; close it once done.
    env = ENV.make_env(clip_rewards=False)
    n_actions = env.action_space.n
    env.close()

    agent = A3C(n_actions=n_actions, lstm=opt.lstm).to(device)
    # map_location lets a CUDA-trained checkpoint load on CPU and vice versa.
    agent.load_state_dict(torch.load(opt.checkpoint, map_location=torch.device(device)))
    agent = agent.eval()

    # force=True overwrites any previous recordings in the video directory.
    env_monitor = gym.wrappers.Monitor(
        ENV.make_env(clip_rewards=False, lstm=opt.lstm),
        directory=opt.video, force=True)
    try:
        reward = evaluate(env_monitor, agent)
        print("Reward: {}".format(reward))
    finally:
        # Always close so the Monitor flushes video files, even on error.
        env_monitor.close()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()