[Feature] ReplayBuffer.empty (#1238)
vmoens authored Jun 7, 2023
1 parent 0961cb3 commit fdd1c8a
Showing 5 changed files with 85 additions and 11 deletions.
20 changes: 20 additions & 0 deletions test/test_rb.py
@@ -536,6 +536,26 @@ def test_add(self, rbtype, storage, size, prefetch):
        else:
            assert (s == data).all()

    def test_empty(self, rbtype, storage, size, prefetch):
        torch.manual_seed(0)
        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
        data = self._get_datum(rbtype)
        for _ in range(2):
            rb.add(data)
            s = rb.sample(1)[0]
            if isinstance(s, TensorDictBase):
                s = s.select(*data.keys(True), strict=False)
                data = data.select(*s.keys(True), strict=False)
                assert (s == data).all()
                assert list(s.keys(True, True))
            else:
                assert (s == data).all()
            rb.empty()
            with pytest.raises(
                RuntimeError, match="Cannot sample from an empty storage"
            ):
                rb.sample()

    def test_extend(self, rbtype, storage, size, prefetch):
        torch.manual_seed(0)
        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
6 changes: 6 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -309,6 +309,12 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]:

        return data, info

    def empty(self):
        """Empties the replay buffer and reset cursor to 0."""
        self._writer._empty()
        self._sampler._empty()
        self._storage._empty()

    def sample(
        self, batch_size: Optional[int] = None, return_info: bool = False
    ) -> Any:
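
A minimal usage sketch of the new method (illustrative only, not part of this diff; it assumes torchrl's ReplayBuffer and ListStorage and relies on the _empty() hooks added in the files below):

# Illustrative only: empty() clears the buffer by delegating to the writer,
# sampler and storage; sampling afterwards raises the new error.
import torch
from torchrl.data import ListStorage, ReplayBuffer

rb = ReplayBuffer(storage=ListStorage(max_size=100))
for i in range(10):
    rb.add(torch.tensor([float(i)]))
rb.sample(4)  # works: the storage currently holds 10 items

rb.empty()  # resets the writer cursor, the sampler state and the storage
try:
    rb.sample(4)
except RuntimeError as err:
    print(err)  # "Cannot sample from an empty storage."
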
35 changes: 31 additions & 4 deletions torchrl/data/replay_buffers/samplers.py
@@ -20,13 +20,15 @@
from .storages import Storage
from .utils import _to_numpy, INT_CLASSES

_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."


class Sampler(ABC):
    """A generic sampler base class for composable Replay Buffers."""

    @abstractmethod
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        raise NotImplementedError
        ...

    def add(self, index: int) -> None:
        return
@@ -57,6 +59,10 @@ def ran_out(self) -> bool:
        # by default, samplers never run out
        return False

    @abstractmethod
    def _empty(self):
        ...


class RandomSampler(Sampler):
    """A uniformly random sampler for composable replay buffers.
@@ -68,9 +74,14 @@ class RandomSampler(Sampler):
    """

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        index = torch.randint(0, len(storage), (batch_size,))
        return index, {}

    def _empty(self):
        pass


class SamplerWithoutReplacement(Sampler):
    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.
@@ -115,6 +126,8 @@ def _single_sample(self, len_storage, batch_size):

    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        len_storage = len(storage)
        if len_storage == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        if not len_storage:
            raise RuntimeError("An empty storage was passed")
        if self.len_storage != len_storage or self._sample_list is None:
@@ -141,6 +154,11 @@ def ran_out(self):
    def ran_out(self, value):
        self._ran_out = value

    def _empty(self):
        self._sample_list = None
        self.len_storage = 0
        self._ran_out = False


class PrioritizedSampler(Sampler):
    """Prioritized sampler for replay buffer.
@@ -182,23 +200,32 @@ def __init__(
        self._beta = beta
        self._eps = eps
        self.reduction = reduction
        if dtype in (torch.float, torch.FloatType, torch.float32):
        self.dtype = dtype
        self._init()

    def _init(self):
        if self.dtype in (torch.float, torch.FloatType, torch.float32):
            self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
            self._min_tree = MinSegmentTreeFp32(self._max_capacity)
        elif dtype in (torch.double, torch.DoubleTensor, torch.float64):
        elif self.dtype in (torch.double, torch.DoubleTensor, torch.float64):
            self._sum_tree = SumSegmentTreeFp64(self._max_capacity)
            self._min_tree = MinSegmentTreeFp64(self._max_capacity)
        else:
            raise NotImplementedError(
                f"dtype {dtype} not supported by PrioritizedSampler"
                f"dtype {self.dtype} not supported by PrioritizedSampler"
            )
        self._max_priority = 1.0

    def _empty(self):
        self._init()

    @property
    def default_priority(self) -> float:
        return (self._max_priority + self._eps) ** self._alpha

    def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        p_sum = self._sum_tree.query(0, len(storage))
        p_min = self._min_tree.query(0, len(storage))
        if p_sum <= 0:
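
Since Sampler._empty is declared abstract above, custom samplers defined outside this repository now have to implement it as well. A hypothetical minimal subclass (illustrative, not from this commit, assuming sample and _empty are the only abstract methods) could look like:

# Hypothetical sampler: uniform sampling over the most recent `n` entries.
# It must define _empty() now that the Sampler base class requires it.
import torch
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import Storage


class RecentSampler(Sampler):
    def __init__(self, n: int = 1000):
        self.n = n

    def sample(self, storage: Storage, batch_size: int):
        if len(storage) == 0:
            raise RuntimeError("Cannot sample from an empty storage.")
        low = max(0, len(storage) - self.n)
        index = torch.randint(low, len(storage), (batch_size,))
        return index, {}  # indices plus an (empty) info dict

    def _empty(self):
        pass  # no internal state to reset beyond `n`
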
24 changes: 19 additions & 5 deletions torchrl/data/replay_buffers/storages.py
@@ -46,11 +46,11 @@ def __init__(self, max_size: int) -> None:

    @abc.abstractmethod
    def set(self, cursor: int, data: Any):
        raise NotImplementedError
        ...

    @abc.abstractmethod
    def get(self, index: int) -> Any:
        raise NotImplementedError
        ...

    def attach(self, buffer: Any) -> None:
        """This function attaches a sampler to this storage.
@@ -80,13 +80,19 @@ def __iter__(self):

    @abc.abstractmethod
    def __len__(self):
        raise NotImplementedError
        ...

    @abc.abstractmethod
    def state_dict(self) -> Dict[str, Any]:
        raise NotImplementedError
        ...

    @abc.abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        raise NotImplementedError
        ...

    @abc.abstractmethod
    def _empty(self):
        ...


class ListStorage(Storage):
@@ -157,6 +163,9 @@ def load_state_dict(self, state_dict):
                    f"Objects of type {type(elt)} are not supported by ListStorage.load_state_dict"
                )

    def _empty(self):
        self._storage = []


class LazyTensorStorage(Storage):
    """A pre-allocated tensor storage for tensors and tensordicts.
@@ -282,6 +291,11 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
    def __len__(self):
        return self._len

    def _empty(self):
        # assuming that the data structure is the same, we don't need to do
        # anything if the cursor is reset to 0
        self._len = 0


class LazyMemmapStorage(LazyTensorStorage):
    """A memory-mapped storage for tensors and tensordicts.
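
LazyTensorStorage._empty above only resets the logical length: the pre-allocated buffer is kept and simply overwritten by later writes. A short illustrative check (not part of this commit; the construction and data are assumptions):

# Illustrative: after empty(), len() is 0 and new writes reuse the
# already-allocated tensor instead of re-allocating it.
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(max_size=8))
rb.extend(torch.arange(8, dtype=torch.float32).unsqueeze(-1))
assert len(rb) == 8

rb.empty()
assert len(rb) == 0  # writer cursor and storage length are back to zero

rb.extend(torch.ones(3, 1))  # overwrites slots 0-2 of the same buffer
assert len(rb) == 3
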
11 changes: 9 additions & 2 deletions torchrl/data/replay_buffers/writers.py
@@ -24,12 +24,16 @@ def register_storage(self, storage: Storage) -> None:
    @abstractmethod
    def add(self, data: Any) -> int:
        """Inserts one piece of data at an appropriate index, and returns that index."""
        raise NotImplementedError
        ...

    @abstractmethod
    def extend(self, data: Sequence) -> torch.Tensor:
        """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices."""
        raise NotImplementedError
        ...

    @abstractmethod
    def _empty(self):
        ...

    def state_dict(self) -> Dict[str, Any]:
        return {}
@@ -73,6 +77,9 @@ def state_dict(self) -> Dict[str, Any]:
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._cursor = state_dict["_cursor"]

    def _empty(self):
        self._cursor = 0


class TensorDictRoundRobinWriter(RoundRobinWriter):
    """A RoundRobin Writer class for composable, tensordict-based replay buffers."""
