simulators.py
from collections import defaultdict, namedtuple

import numpy as np
import torch

from autoassign import autoassign
from envs.ant_gather import AntGatherEnv
from envs.point_gather import PointGatherEnv
from memory import Memory, Trajectory
from torch_utils.torch_utils import get_device


def make_env(env_name, **env_args):
    """Construct a gather environment by name."""
    if env_name == 'ant_gather':
        return AntGatherEnv(**env_args)
    elif env_name == 'point_gather':
        return PointGatherEnv(**env_args)
    else:
        raise NotImplementedError


class Simulator:
    """Holds a batch of environments and the shared policy used to roll them out."""

    @autoassign(exclude=('env_name', 'env_args'))
    def __init__(self, env_name, policy, n_trajectories, trajectory_len, obs_filter=None,
                 **env_args):
        # One independent environment instance per trajectory to be collected
        self.env = np.asarray([make_env(env_name, **env_args) for i in range(n_trajectories)])
        self.n_trajectories = n_trajectories

        for env in self.env:
            env._max_episode_steps = trajectory_len

        self.device = get_device()


class SinglePathSimulator(Simulator):
    def __init__(self, env_name, policy, n_trajectories, trajectory_len, state_filter=None,
                 **env_args):
        # state_filter is forwarded as the parent class's obs_filter
        Simulator.__init__(self, env_name, policy, n_trajectories, trajectory_len, state_filter,
                           **env_args)

    def run_sim(self):
        self.policy.eval()

        with torch.no_grad():
            trajectories = np.asarray([Trajectory() for i in range(self.n_trajectories)])
            continue_mask = np.ones(self.n_trajectories)

            # Reset every environment and record the initial observation
            for env, trajectory in zip(self.env, trajectories):
                obs = torch.tensor(env.reset()).float()

                # Maybe batch this operation later
                if self.obs_filter:
                    obs = self.obs_filter(obs)

                trajectory.observations.append(obs)

            # Step all unfinished trajectories in lockstep until every one is done
            while np.any(continue_mask):
                continue_indices = np.where(continue_mask)
                trajs_to_update = trajectories[continue_indices]
                continuing_envs = self.env[continue_indices]

                policy_input = torch.stack([torch.tensor(trajectory.observations[-1]).to(self.device)
                                            for trajectory in trajs_to_update])

                action_dists = self.policy(policy_input)
                actions = action_dists.sample()
                actions = actions.cpu()

                for env, action, trajectory in zip(continuing_envs, actions, trajs_to_update):
                    obs, reward, trajectory.done, info = env.step(action.numpy())

                    obs = torch.tensor(obs).float()
                    reward = torch.tensor(reward, dtype=torch.float)
                    cost = torch.tensor(info['constraint_cost'], dtype=torch.float)

                    if self.obs_filter:
                        obs = self.obs_filter(obs)

                    trajectory.actions.append(action)
                    trajectory.rewards.append(reward)
                    trajectory.costs.append(cost)

                    # Only unfinished trajectories receive the next observation
                    if not trajectory.done:
                        trajectory.observations.append(obs)

                continue_mask = np.asarray([1 - trajectory.done for trajectory in trajectories])

        memory = Memory(trajectories)

        return memory
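

if __name__ == '__main__':
    # Illustrative usage sketch, not part of the original module. The stand-in
    # Gaussian policy, the observation/action dimensions, and PointGatherEnv's
    # default constructor arguments below are assumptions; the real project
    # supplies its own policy network and environment configuration.
    import torch.nn as nn
    from torch.distributions import Normal

    class GaussianPolicy(nn.Module):
        """Minimal policy returning an action distribution, as run_sim expects."""

        def __init__(self, obs_dim, action_dim):
            super().__init__()
            self.mean = nn.Linear(obs_dim, action_dim)
            self.log_std = nn.Parameter(torch.zeros(action_dim))

        def forward(self, obs):
            return Normal(self.mean(obs), self.log_std.exp())

    obs_dim, action_dim = 29, 2  # placeholder dimensions for illustration only
    policy = GaussianPolicy(obs_dim, action_dim).to(get_device())

    simulator = SinglePathSimulator('point_gather', policy, n_trajectories=4,
                                    trajectory_len=15)
    memory = simulator.run_sim()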