Commit 0bd30eb

[Performance] Hold a single copy of low/high in bounded specs

ghstack-source-id: d6d64e3
Pull-Request-resolved: #2977

1 parent 72c60fb · commit 0bd30eb

2 files changed: +111 -0 lines changed

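In broad strokes, the change stores the bounds of a Bounded spec without their batch dimensions whenever every batch element shares the same low/high, and re-expands them on access. A rough sketch of the idea in plain PyTorch (illustrative only, not code from this commit):

import torch

batch_size = torch.Size([10])
low = -torch.ones(10, 2)   # ten identical rows
high = torch.ones(10, 2)

# Keep a single unbatched copy of each bound.
low_single = low.flatten(0, len(batch_size) - 1)[0].clone()    # shape (2,)
high_single = high.flatten(0, len(batch_size) - 1)[0].clone()  # shape (2,)

# Rebuild the batched tensor lazily when it is requested.
low_batched = low_single.expand(*batch_size, *low_single.shape)
assert torch.equal(low_batched, low)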

test/test_specs.py

Lines changed: 49 additions & 0 deletions

@@ -3999,6 +3999,55 @@ def test_device_ordinal():
     assert spec.device == torch.device("cuda:0")
 
 
+class TestBatchSizeBox:
+    def test_batch_size_box_same(self):
+        spec = Bounded(shape=(10, 2), low=-1, high=1, device=torch.device("cpu"))
+        spec.shape = (10, 2)
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == (10, 2)
+        assert spec.space._high.shape == (10, 2)
+        c_spec = Composite(b=spec, shape=(10,))
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == (2,)
+        assert spec.space._high.shape == (2,)
+        c_spec = Composite(b=spec, shape=(10, 2))
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == ()
+        assert spec.space._high.shape == ()
+        c_spec = Composite(b=spec, shape=())
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == (10, 2)
+        assert spec.space._high.shape == (10, 2)
+
+    def test_batch_size_box_diff(self):
+        spec = Bounded(
+            shape=(10, 2),
+            low=-torch.arange(20).view(10, 2),
+            high=torch.arange(20).view(10, 2),
+            device=torch.device("cpu"),
+        )
+        spec.shape = (10, 2)
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == (10, 2)
+        assert spec.space._high.shape == (10, 2)
+        c_spec = Composite(b=spec, shape=(10,))
+        assert spec.shape == (10, 2)
+        assert spec.space.low.shape == (10, 2)
+        assert spec.space.high.shape == (10, 2)
+        assert spec.space._low.shape == (10, 2)
+        assert spec.space._high.shape == (10, 2)
+
+
 class TestLegacy:
     def test_one_hot(self):
         with pytest.warns(
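When the bounds genuinely differ across the batch there is no single row to share, so the box keeps the full copy, which is what test_batch_size_box_diff checks. A quick sketch of why the deduplication cannot apply there (illustrative only, not part of the diff):

import torch

low = -torch.arange(20).view(10, 2).float()
# The rows are not all equal to the first one, so a single shared row
# cannot represent the whole batch and the (10, 2) tensor is kept as-is.
assert not torch.allclose(low[0], low)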

torchrl/data/tensor_specs.py

Lines changed: 62 additions & 0 deletions

@@ -391,20 +391,61 @@ class ContinuousBox(Box):
     _low: torch.Tensor
     _high: torch.Tensor
     device: torch.device | None = None
+    _batch_size: torch.Size | None = None
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value: torch.Size | tuple):
+        # Check batch size is compatible with low and high
+        value = _remove_neg_shapes(value)
+        if self._batch_size is None:
+            if value != self._low.shape[: len(value)]:
+                raise ValueError(
+                    f"Batch size {value} is not compatible with low and high {self._low.shape}"
+                )
+        if value is None:
+            self._batch_size = None
+            self._low = self.low.clone()
+            self._high = self.high.clone()
+            return
+        # Remove batch size from low and high
+        if value:
+            # Check that low and high have a single value
+            td_low_high = TensorDict(
+                low=self.low, high=self.high, batch_size=value
+            ).flatten()
+            td_low_high0 = td_low_high[0]
+            if torch.allclose(
+                td_low_high0["low"], td_low_high["low"]
+            ) and torch.allclose(td_low_high0["high"], td_low_high["high"]):
+                self._low = td_low_high0["low"].clone()
+                self._high = td_low_high0["high"].clone()
+                self._batch_size = torch.Size(value)
+        else:
+            self._low = self.low.clone()
+            self._high = self.high.clone()
+            self._batch_size = torch.Size(value)
 
     # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
     @property
     def low(self):
         low = self._low
         if self.device is not None and low.device != self.device:
             low = low.to(self.device)
+        if self._batch_size:
+            low = low.expand((*self._batch_size, *low.shape)).clone()
         return low
 
     @property
     def high(self):
         high = self._high
         if self.device is not None and high.device != self.device:
             high = high.to(self.device)
+        if self._batch_size:
+            high = high.expand((*self._batch_size, *high.shape)).clone()
         return high
 
     def unbind(self, dim: int = 0):
@@ -417,15 +458,30 @@ def unbind(self, dim: int = 0):
     def low(self, value):
         self.device = value.device
         self._low = value
+        if self._batch_size is not None:
+            if value.shape[: len(self._batch_size)] != self._batch_size:
+                raise ValueError(
+                    f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
+                )
+            if self._batch_size:
+                self._low = self._low.flatten(0, len(self._batch_size) - 1)[0].clone()
 
     @high.setter
     def high(self, value):
         self.device = value.device
         self._high = value
+        if self._batch_size is not None:
+            if value.shape[: len(self._batch_size)] != self._batch_size:
+                raise ValueError(
+                    f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
+                )
+            if self._batch_size:
+                self._high = self._high.flatten(0, len(self._batch_size) - 1)[0].clone()
 
     def __post_init__(self):
         self.low = self.low.clone()
         self.high = self.high.clone()
+        self._batch_size = None
 
     def __iter__(self):
         yield self.low
@@ -2366,6 +2422,10 @@ def __init__(
         )
         self.encode = self._encode_eager
 
+    def _register_batch_size(self, batch_size: torch.Size | tuple):
+        # Register batch size in the space to decrease the memory footprint of the specs
+        self.space.batch_size = batch_size
+
     def index(
         self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
     ) -> torch.Tensor | TensorDictBase:
@@ -5191,6 +5251,8 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
                 f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                 f"Composite.shape={self.shape}."
             )
+        if isinstance(spec, Bounded):
+            spec._register_batch_size(self.shape)
         self._specs[name] = spec
         return self
 
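Taken together, the effect can be reproduced through the public API as in test_batch_size_box_same above; a minimal usage sketch (assuming torchrl's Bounded and Composite imports):

import torch
from torchrl.data import Bounded, Composite

spec = Bounded(shape=(10, 2), low=-1, high=1, device=torch.device("cpu"))
assert spec.space._low.shape == (10, 2)   # full copy stored up front

# Registering the spec in a Composite with batch shape (10,) lets the
# ContinuousBox keep a single (2,)-shaped copy of low/high internally...
c_spec = Composite(b=spec, shape=(10,))
assert spec.space._low.shape == (2,)

# ...while the public properties still expand back to the batched shape.
assert spec.space.low.shape == (10, 2)
assert spec.space.high.shape == (10, 2)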
