Separates per-step termination and last-episode termination bookkeeping #3745
base: main
Changes from all commits: fd5d8cf, 8f69051, e020d70, d60a875
```diff
@@ -63,6 +63,8 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):
         self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}
         # prepare extra info to store individual termination term information
         self._term_dones = torch.zeros((self.num_envs, len(self._term_names)), device=self.device, dtype=torch.bool)
+        # prepare extra info to store last episode done per termination term information
+        self._last_episode_dones = torch.zeros_like(self._term_dones)
         # create buffer for managing termination per environment
         self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
         self._terminated_buf = torch.zeros_like(self._truncated_buf)
@@ -138,7 +140,7 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
             env_ids = slice(None)
         # add to episode dict
         extras = {}
-        last_episode_done_stats = self._term_dones.float().mean(dim=0)
+        last_episode_done_stats = self._last_episode_dones.float().mean(dim=0)
         for i, key in enumerate(self._term_names):
             # store information
             extras["Episode_Termination/" + key] = last_episode_done_stats[i].item()
```
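With `reset()` now reading from `_last_episode_dones`, each logged `Episode_Termination/<term>` entry becomes the fraction of environments whose last recorded episode ended because of that term. A minimal sketch of that reduction, using invented buffer contents (the four environments, two term columns, and values below are illustrative, not taken from the PR):

```python
import torch

# hypothetical last-episode record for 4 environments and 2 terms
# (columns ordered as ["term_a", "term_b"]; all values invented)
last_episode_dones = torch.tensor(
    [
        [True, False],   # env 0: last episode ended because of term_a
        [False, True],   # env 1: last episode ended because of term_b
        [True, False],   # env 2: last episode ended because of term_a
        [False, False],  # env 3: no termination recorded yet
    ]
)

# same reduction as the new reset() code: per-term fraction of environments
stats = last_episode_dones.float().mean(dim=0)
print(stats)  # tensor([0.5000, 0.2500])
```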
```diff
@@ -169,15 +171,17 @@ def compute(self) -> torch.Tensor:
             else:
                 self._terminated_buf |= value
             # add to episode dones
-            rows = value.nonzero(as_tuple=True)[0]  # indexing is cheaper than boolean advance indexing
-            if rows.numel() > 0:
-                self._term_dones[rows] = False
-                self._term_dones[rows, i] = True
+            self._term_dones[:, i] = value
+        # update last-episode dones once per compute: for any env where a term fired,
+        # reflect exactly which term(s) fired this step and clear others
+        rows = self._term_dones.any(dim=1).nonzero(as_tuple=True)[0]
+        if rows.numel() > 0:
+            self._last_episode_dones[rows] = self._term_dones[rows]
```
Comment on lines +177 to +179

style: The previous comment about overwriting is actually describing intended behavior. When term A fires at step 3 and term B then fires at step 5 before a reset, `_last_episode_dones` is meant to record only the term(s) that fired in the most recent termination event.
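A minimal, self-contained sketch of the timeline discussed in this thread, applying the same two-buffer update rule as the new `compute()` code to hand-written per-step values; the two-term layout, the `update` helper, and the step numbers are illustrative, not part of the PR:

```python
import torch

num_envs, num_terms = 4, 2  # columns: [term_A, term_B] (illustrative)
term_dones = torch.zeros((num_envs, num_terms), dtype=torch.bool)
last_episode_dones = torch.zeros_like(term_dones)


def update(step_values: torch.Tensor) -> None:
    """Apply the PR's bookkeeping: record this step, then copy rows where any term fired."""
    term_dones.copy_(step_values)
    rows = term_dones.any(dim=1).nonzero(as_tuple=True)[0]
    if rows.numel() > 0:
        last_episode_dones[rows] = term_dones[rows]


# step 3: term_A fires in every env -> record becomes [True, False]
update(torch.tensor([[True, False]] * num_envs))
print(last_episode_dones[0])  # tensor([ True, False])

# step 4: nothing fires -> the last-episode record persists
update(torch.zeros((num_envs, num_terms), dtype=torch.bool))
print(last_episode_dones[0])  # tensor([ True, False])

# step 5, still no reset: term_B fires -> the record is overwritten with [False, True]
update(torch.tensor([[False, True]] * num_envs))
print(last_episode_dones[0])  # tensor([False,  True])
```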
```diff
         # return combined termination signal
         return self._truncated_buf | self._terminated_buf

     def get_term(self, name: str) -> torch.Tensor:
-        """Returns the termination term with the specified name.
+        """Returns the termination term value at current step with the specified name.

         Args:
             name: The name of the termination term.
@@ -190,7 +194,8 @@ def get_term(self, name: str) -> torch.Tensor:
     def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
         """Returns the active terms as iterable sequence of tuples.

-        The first element of the tuple is the name of the term and the second element is the raw value(s) of the term.
+        The first element of the tuple is the name of the term and the second element is the raw value(s) of the term
+        recorded at current step.

         Args:
             env_idx: The specific environment to pull the active terms from.
```
New test file (@@ -0,0 +1,140 @@):

```python
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Launch Isaac Sim Simulator first."""

from isaaclab.app import AppLauncher

# launch omniverse app
simulation_app = AppLauncher(headless=True).app

"""Rest everything follows."""

import torch

import pytest

from isaaclab.managers import TerminationManager, TerminationTermCfg
from isaaclab.sim import SimulationContext


class DummyEnv:
    """Minimal mutable env stub for the termination manager tests."""

    def __init__(self, num_envs: int, device: str, sim: SimulationContext):
        self.num_envs = num_envs
        self.device = device
        self.sim = sim
        self.counter = 0  # mutable step counter used by test terms


def fail_every_5_steps(env) -> torch.Tensor:
    """Returns True for all envs when counter is a positive multiple of 5."""
    cond = env.counter > 0 and (env.counter % 5 == 0)
    return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)


def fail_every_10_steps(env) -> torch.Tensor:
    """Returns True for all envs when counter is a positive multiple of 10."""
    cond = env.counter > 0 and (env.counter % 10 == 0)
    return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)


def fail_every_3_steps(env) -> torch.Tensor:
    """Returns True for all envs when counter is a positive multiple of 3."""
    cond = env.counter > 0 and (env.counter % 3 == 0)
    return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)


@pytest.fixture
def env():
    sim = SimulationContext()
    return DummyEnv(num_envs=20, device="cpu", sim=sim)


def test_initial_state_and_shapes(env):
    cfg = {
        "term_5": TerminationTermCfg(func=fail_every_5_steps),
        "term_10": TerminationTermCfg(func=fail_every_10_steps),
    }
    tm = TerminationManager(cfg, env)

    # Active term names
    assert tm.active_terms == ["term_5", "term_10"]

    # Internal buffers have expected shapes and start as all False
    assert tm._term_dones.shape == (env.num_envs, 2)
    assert tm._last_episode_dones.shape == (env.num_envs, 2)
    assert tm.dones.shape == (env.num_envs,)
    assert tm.time_outs.shape == (env.num_envs,)
    assert tm.terminated.shape == (env.num_envs,)
    assert torch.all(~tm._term_dones) and torch.all(~tm._last_episode_dones)


def test_term_transitions_and_persistence(env):
    """Concise transitions: single fire, persist, switch, both, persist.

    Uses 3-step and 5-step terms and verifies current-step values and last-episode persistence.
    """
    cfg = {
        "term_3": TerminationTermCfg(func=fail_every_3_steps, time_out=False),
        "term_5": TerminationTermCfg(func=fail_every_5_steps, time_out=False),
    }
    tm = TerminationManager(cfg, env)

    # step 3: only term_3 -> last_episode [True, False]
    env.counter = 3
    out = tm.compute()
    assert torch.all(tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
    assert torch.all(out)
    assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(~tm._last_episode_dones[:, 1])

    # step 4: none -> last_episode persists [True, False]
    env.counter = 4
    out = tm.compute()
    assert torch.all(~out)
    assert torch.all(~tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
    assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(~tm._last_episode_dones[:, 1])

    # step 5: only term_5 -> last_episode [False, True]
    env.counter = 5
    out = tm.compute()
    assert torch.all(~tm.get_term("term_3")) and torch.all(tm.get_term("term_5"))
    assert torch.all(out)
    assert torch.all(~tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])

    # step 15: both -> last_episode [True, True]
    env.counter = 15
    out = tm.compute()
    assert torch.all(tm.get_term("term_3")) and torch.all(tm.get_term("term_5"))
    assert torch.all(out)
    assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])

    # step 16: none -> persist [True, True]
    env.counter = 16
    out = tm.compute()
    assert torch.all(~out)
    assert torch.all(~tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
    assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])


def test_time_out_vs_terminated_split(env):
    cfg = {
        "term_5": TerminationTermCfg(func=fail_every_5_steps, time_out=False),  # terminated
        "term_10": TerminationTermCfg(func=fail_every_10_steps, time_out=True),  # timeout
    }
    tm = TerminationManager(cfg, env)

    # Step 5: terminated fires, not timeout
    env.counter = 5
    out = tm.compute()
    assert torch.all(out)
    assert torch.all(tm.terminated) and torch.all(~tm.time_outs)

    # Step 10: both fire; timeout and terminated both True
    env.counter = 10
    out = tm.compute()
    assert torch.all(out)
    assert torch.all(tm.terminated) and torch.all(tm.time_outs)
```
logic: When multiple terms fire for the same environment in one step, `self._last_episode_dones[rows] = self._term_dones[rows]` correctly captures all firing terms. However, when a different term fires in a subsequent step for the same environment before reset, this overwrites the previous term's record. For example: `_last_episode_dones[0] = [True, False]`, then `_last_episode_dones[0] = [False, True]`. This causes the logging to miss the original termination cause.
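To make the trade-off concrete, here is a small sketch contrasting the PR's overwrite with a hypothetical OR-accumulation that would retain the earlier cause; the `|=` variant and the reset-time clearing in the final comment are assumptions for illustration, not part of this change:

```python
import torch

# env 0's record left over from step 3 (term A), and the step-5 values (term B only)
last_episode_dones = torch.tensor([[True, False]])
term_dones = torch.tensor([[False, True]])
rows = term_dones.any(dim=1).nonzero(as_tuple=True)[0]

# behavior in this PR: overwrite -> the step-3 cause (term A) is no longer logged
overwritten = last_episode_dones.clone()
overwritten[rows] = term_dones[rows]
print(overwritten)  # tensor([[False,  True]])

# hypothetical alternative: OR-accumulate so both causes survive until the env resets
accumulated = last_episode_dones.clone()
accumulated[rows] |= term_dones[rows]
print(accumulated)  # tensor([[ True,  True]])
# such a variant would also need to clear the affected rows on reset, e.g.
# accumulated[reset_env_ids] = False  (reset_env_ids is hypothetical)
```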