diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 5e9b6dd75be..21245e37acd 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -917,7 +917,7 @@ def _sample_slices( truncated[seq_length.cumsum(0) - 1] = 1 traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 terminated = torch.zeros_like(truncated) - if terminated.any(): + if traj_terminated.any(): if isinstance(seq_length, int): truncated.view(num_slices, -1)[traj_terminated] = 1 else: