
Commit 469c871

[Feature] Add support for null dim argument in TensorDict.squeeze (#608)

* Start adding support for null `dim` argument in `TensorDict.squeeze`

* Address PR comments re: style

* Fully implement and test `TensorDict.squeeze` with `dim=None`

* Fix linting errors
jgonik authored Oct 31, 2022
1 parent 86380c8 commit 469c871
Showing 2 changed files with 31 additions and 4 deletions.
test/test_tensordict.py (12 additions, 0 deletions)

@@ -1093,6 +1093,18 @@ def test_squeeze(self, td_name, device, squeeze_dim=-1):
        assert (td_squeeze.get("a") == 1).all()
        assert (td.get("a") == 1).all()

    def test_squeeze_with_none(self, td_name, device, squeeze_dim=None):
        torch.manual_seed(1)
        td = getattr(self, td_name)(device)
        td_squeeze = torch.squeeze(td, dim=None)
        tensor = torch.ones_like(td.get("a").squeeze())
        td_squeeze.set_("a", tensor)
        assert (td_squeeze.get("a") == tensor).all()
        if td_name == "unsqueezed_td":
            assert td_squeeze._source is td
        assert (td_squeeze.get("a") == 1).all()
        assert (td.get("a") == 1).all()

    def test_write_on_subtd(self, td_name, device):
        td = getattr(self, td_name)(device)
        sub_td = td.get_sub_tensordict(0)
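For context, here is a minimal usage sketch of the behaviour this commit enables. It is illustrative only, not part of the diff: the shapes are made up, and it assumes the `torchrl.data` import path in use at the time of this commit.

import torch
from torchrl.data import TensorDict  # import path assumed for this torchrl version

# Hypothetical TensorDict whose batch size contains a singleton dimension.
td = TensorDict({"a": torch.randn(3, 1, 4, 5)}, batch_size=[3, 1, 4])

# With dim=None (the new default), every singleton batch dimension is squeezed.
td_squeezed = td.squeeze()  # equivalent to torch.squeeze(td, dim=None), as in the test above
print(td_squeezed.batch_size)       # torch.Size([3, 4])
print(td_squeezed.get("a").shape)   # torch.Size([3, 4, 5])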
torchrl/data/tensordict/tensordict.py (19 additions, 4 deletions)

@@ -1196,13 +1196,28 @@ def unsqueeze(self, dim: int) -> TensorDictBase:
            inv_op_kwargs={"dim": dim},
        )

    def squeeze(self, dim: int) -> TensorDictBase:
    def squeeze(self, dim: Optional[int] = None) -> TensorDictBase:
        """Squeezes all tensors for a dimension comprised in between `-td.batch_dims+1` and `td.batch_dims-1` and returns them in a new tensordict.
        Args:
            dim (int): dimension along which to squeeze
            dim (Optional[int]): dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. dim is None by default.
        """
        if dim is None:
            size = self.size()
            if len(self.size()) == 1 or size.count(1) == 0:
                return self
            first_singleton_dim = size.index(1)

            squeezed_dict = SqueezedTensorDict(
                source=self,
                custom_op="squeeze",
                inv_op="unsqueeze",
                custom_op_kwargs={"dim": first_singleton_dim},
                inv_op_kwargs={"dim": first_singleton_dim},
            )
            return squeezed_dict.squeeze(dim=None)

        if dim < 0:
            dim = self.batch_dims + dim

@@ -4489,8 +4504,8 @@ class UnsqueezedTensorDict(_CustomOpTensorDict):
        True
    """

    def squeeze(self, dim: int) -> TensorDictBase:
        if dim < 0:
    def squeeze(self, dim: Optional[int]) -> TensorDictBase:
        if dim is not None and dim < 0:
            dim = self.batch_dims + dim
        if dim == self.custom_op_kwargs.get("dim"):
            return self._source
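The `dim=None` branch above works by wrapping the source in a lazy `SqueezedTensorDict` view for the first singleton batch dimension and then recursing until no singleton dimension remains (or only one dimension is left). A rough standalone sketch of that recursion on a plain list of sizes, for illustration only (the helper `squeeze_all` is hypothetical and not part of the commit):

from typing import List

def squeeze_all(size: List[int]) -> List[int]:
    # Mirrors the base cases and the "first singleton, then recurse" step
    # used by squeeze(dim=None) in the diff above.
    if len(size) == 1 or size.count(1) == 0:
        return size
    first_singleton_dim = size.index(1)
    remaining = size[:first_singleton_dim] + size[first_singleton_dim + 1:]
    return squeeze_all(remaining)

print(squeeze_all([3, 1, 4, 1]))  # [3, 4]
print(squeeze_all([1]))           # [1] -- a lone dimension is preserved, as in the diff

Because each level of the recursion adds one more lazy view, squeezing with `dim=None` generally returns a chain of `SqueezedTensorDict`s over the original source rather than a copy.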
