diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 0f5d65a9487..b88e784fade 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -1093,6 +1093,18 @@ def test_squeeze(self, td_name, device, squeeze_dim=-1):
         assert (td_squeeze.get("a") == 1).all()
         assert (td.get("a") == 1).all()
 
+    def test_squeeze_with_none(self, td_name, device, squeeze_dim=None):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        td_squeeze = torch.squeeze(td, dim=None)
+        tensor = torch.ones_like(td.get("a").squeeze())
+        td_squeeze.set_("a", tensor)
+        assert (td_squeeze.get("a") == tensor).all()
+        if td_name == "unsqueezed_td":
+            assert td_squeeze._source is td
+        assert (td_squeeze.get("a") == 1).all()
+        assert (td.get("a") == 1).all()
+
     def test_write_on_subtd(self, td_name, device):
         td = getattr(self, td_name)(device)
         sub_td = td.get_sub_tensordict(0)
diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py
index efcb1eb5fca..3a009887a5a 100644
--- a/torchrl/data/tensordict/tensordict.py
+++ b/torchrl/data/tensordict/tensordict.py
@@ -1196,13 +1196,28 @@ def unsqueeze(self, dim: int) -> TensorDictBase:
             inv_op_kwargs={"dim": dim},
         )
 
-    def squeeze(self, dim: int) -> TensorDictBase:
+    def squeeze(self, dim: Optional[int] = None) -> TensorDictBase:
         """Squeezes all tensors for a dimension comprised in between
         `-td.batch_dims+1` and `td.batch_dims-1` and returns them in a new
         tensordict.
 
         Args:
-            dim (int): dimension along which to squeeze
+            dim (Optional[int]): dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. dim is None by default.
         """
+        if dim is None:
+            size = self.size()
+            if len(self.size()) == 1 or size.count(1) == 0:
+                return self
+            first_singleton_dim = size.index(1)
+
+            squeezed_dict = SqueezedTensorDict(
+                source=self,
+                custom_op="squeeze",
+                inv_op="unsqueeze",
+                custom_op_kwargs={"dim": first_singleton_dim},
+                inv_op_kwargs={"dim": first_singleton_dim},
+            )
+            return squeezed_dict.squeeze(dim=None)
+
         if dim < 0:
             dim = self.batch_dims + dim
 
@@ -4489,8 +4504,8 @@ class UnsqueezedTensorDict(_CustomOpTensorDict):
         True
 
     """
-    def squeeze(self, dim: int) -> TensorDictBase:
-        if dim < 0:
+    def squeeze(self, dim: Optional[int]) -> TensorDictBase:
+        if dim is not None and dim < 0:
             dim = self.batch_dims + dim
         if dim == self.custom_op_kwargs.get("dim"):
             return self._source
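
A minimal usage sketch (not part of the patch) of the behavior the new `dim=None` default is meant to provide: all singleton batch dimensions are squeezed away by chaining one `SqueezedTensorDict` view per singleton dim, while an explicit `dim` keeps the previous single-dimension behavior. The tensordict contents, shapes, and variable names below are illustrative assumptions, not code from the patch.

    import torch

    from torchrl.data.tensordict.tensordict import TensorDict

    # Illustrative tensordict with two singleton batch dims: batch_size [3, 1, 4, 1].
    td = TensorDict(
        {"a": torch.zeros(3, 1, 4, 1, 5)},
        batch_size=[3, 1, 4, 1],
    )

    # With the patch, squeeze() without a dim (dim=None) is expected to drop
    # every singleton batch dimension: [3, 1, 4, 1] -> [3, 4].
    td_all = td.squeeze()
    assert td_all.batch_size == torch.Size([3, 4])
    assert td_all.get("a").shape == torch.Size([3, 4, 5])

    # Passing an explicit dim squeezes only that dimension, as before:
    # [3, 1, 4, 1] -> [3, 4, 1].
    td_one = td.squeeze(dim=1)
    assert td_one.batch_size == torch.Size([3, 4, 1])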