-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmemory.py
30 lines (22 loc) · 923 Bytes
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from itertools import chain
import torch
class Trajectory:
    """Per-step record buffers plus a completion flag.

    Holds four independent lists (``observations``, ``actions``, ``rewards``
    and ``costs``), all starting empty, and a ``done`` flag that starts as
    ``False``. Callers append to the lists directly.
    NOTE(review): presumably one episode of agent/environment interaction,
    with the four lists kept the same length by the caller - confirm.
    """

    def __init__(self):
        # A fresh list per field, so no two instances (or fields) share storage.
        for field_name in ("observations", "actions", "rewards", "costs"):
            setattr(self, field_name, [])
        self.done = False

    def __len__(self):
        """Return the number of observations recorded so far."""
        recorded = self.observations
        return len(recorded)
class Memory:
    """A batch of trajectories that can be flattened into training tensors.

    ``trajectories`` is stored as given (not copied). Each element must expose
    ``observations``, ``actions``, ``rewards`` and ``costs`` sequences, and the
    container must support indexing for ``memory[i]`` to work.
    """

    def __init__(self, trajectories):
        self.trajectories = trajectories

    def sample(self):
        """Concatenate every recorded step of every trajectory.

        Observations and actions are stacked per trajectory with
        ``torch.stack`` (so within a trajectory they must be tensors of equal
        shape) and rewards and costs are converted with ``torch.tensor``; the
        per-trajectory results are then joined along dimension 0 in
        trajectory order.

        Trajectories that hold no observations are skipped: they contribute
        no rows, and ``torch.stack`` would otherwise fail on their empty
        lists and abort the whole batch.

        Returns:
            A tuple ``(observations, actions, rewards, costs)`` of tensors.

        Raises:
            RuntimeError: Propagated from ``torch.cat`` when no trajectory
                contains any step, since there is nothing to concatenate.
        """
        # Filter once so all four tensors are built from the same trajectories
        # and therefore stay aligned row for row.
        populated = [
            trajectory
            for trajectory in self.trajectories
            if len(trajectory.observations) > 0
        ]
        observations = torch.cat(
            [torch.stack(trajectory.observations) for trajectory in populated]
        )
        actions = torch.cat(
            [torch.stack(trajectory.actions) for trajectory in populated]
        )
        rewards = torch.cat(
            [torch.tensor(trajectory.rewards) for trajectory in populated]
        )
        costs = torch.cat(
            [torch.tensor(trajectory.costs) for trajectory in populated]
        )
        return observations, actions, rewards, costs

    def __getitem__(self, i):
        """Return ``self.trajectories[i]`` (index or slice, as the container allows)."""
        return self.trajectories[i]