Eval.py
"""
Evaluate an agent based on average number of
steps to finish an environment
"""
from tqdm import tqdm
import os
import os.path as osp
import pickle
import numpy as np
import gym
import sys
import time
from rlpyt.agents.pg.categorical import CategoricalPgAgent
from Network import *
from collections import namedtuple
def simulateAgentFile(agentFile, render=False):
    """ Load an rlpyt agent from a checkpoint file and simulate it. """
    state_dict = torch.load(agentFile, map_location=torch.device('cpu'))
    agent = CategoricalPgAgent(AcrobotNet)
    env = gym.make('Acrobot-v1')
    # Mimic rlpyt's env_spaces so the agent can construct its model.
    EnvSpace = namedtuple('EnvSpace', ['action', 'observation'])
    agent.initialize(EnvSpace(env.action_space, env.observation_space))
    agent.load_state_dict(state_dict)
    return simulateAgent(agent, render)


def simulateAgent(agent, render=False):
    """
    Simulate the agent on the environment until the
    episode ends and return the number of steps taken.
    """
    env = gym.make('Acrobot-v1')
    done = False
    s = torch.tensor(env.reset()).float()
    a = torch.tensor(0)
    r = torch.tensor(0).float()
    i = 0
    while not done:
        i += 1
        if render:
            env.render()
            time.sleep(0.05)
        # rlpyt agents step on (observation, prev_action, prev_reward).
        a = agent.step(s, a, r).action
        s, r, done, _ = env.step(a.item())
        s = torch.tensor(s).float()
        r = torch.tensor(r).float()
    if render:
        env.render()
        time.sleep(0.05)
    env.close()
    return i
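

# The module docstring describes evaluating by the *average* number of
# steps, but no averaging helper appears above. Below is a minimal
# sketch of one; the name `averageSteps` and the default episode count
# are illustrative assumptions, not part of the original file.
def averageSteps(agentFile, nEpisodes=10):
    """
    Run the saved agent for several episodes and return the
    average number of steps needed to finish the environment.
    """
    # Reloading the agent each episode is wasteful but keeps the
    # sketch self-contained on top of simulateAgentFile.
    steps = [simulateAgentFile(agentFile) for _ in range(nEpisodes)]
    return sum(steps) / len(steps)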