Skip to content

Commit 8dcf987

Browse files
louisfaury and Louis Faury authored
[BugFix] Categorical spec samples the right dtype when masked (#2980)
Co-authored-by: Louis Faury <[email protected]>
1 parent 74d6cbc commit 8dcf987

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

test/test_specs.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1331,6 +1331,14 @@ def test_categorical_action_spec_rand(self):
13311331
sample = [sum(sample == i) for i in range(10)]
13321332
assert chisquare(sample).pvalue > 0.1
13331333

1334+
@pytest.mark.parametrize("dtype", [torch.int, torch.int32, torch.int64])
1335+
def test_categorical_action_spec_rand_masked_right_dtype(self, dtype: torch.dtype):
1336+
torch.manual_seed(1)
1337+
action_spec = Categorical(2, dtype=dtype)
1338+
action_spec.update_mask(torch.tensor([True, False]))
1339+
sample = action_spec.rand()
1340+
assert sample.dtype == dtype
1341+
13341342
def test_mult_discrete_action_spec_rand(self):
13351343
torch.manual_seed(0)
13361344
ns = (10, 5)

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3926,7 +3926,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
39263926
"The last dimension of the mask must match the number of action allowed by the "
39273927
f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}."
39283928
)
3929-
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
3929+
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out).to(self.dtype)
39303930
return out
39313931

39323932
def index(

0 commit comments

Comments
 (0)