From 6bd92965f540cc621eb3ab8a1cfcc056537529b7 Mon Sep 17 00:00:00 2001 From: Remi Date: Wed, 7 Feb 2024 17:40:11 +0100 Subject: [PATCH] [BugFix] Use traj_terminated in SliceSampler (#1884) --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 5e9b6dd75be..21245e37acd 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -917,7 +917,7 @@ def _sample_slices( truncated[seq_length.cumsum(0) - 1] = 1 traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 terminated = torch.zeros_like(truncated) - if terminated.any(): + if traj_terminated.any(): if isinstance(seq_length, int): truncated.view(num_slices, -1)[traj_terminated] = 1 else: