Skip to content

Commit 5a46379

Browse files
author
Vincent Moens
committed
[Feature] reset_time in Timer
ghstack-source-id: 36a74fd Pull Request resolved: #2807
1 parent 104b880 commit 5a46379

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

test/test_transforms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13949,7 +13949,15 @@ def test_transform_env(self):
1394913949
# The stack must be contiguous
1395013950
assert not isinstance(rollout, LazyStackedTensorDict)
1395113951
assert (rollout["time_policy"] >= 0).all()
13952-
assert (rollout["time_step"] > 0).all()
13952+
assert (rollout["time_step"] >= 0).all()
13953+
env.append_transform(StepCounter(max_steps=5))
13954+
rollout = env.rollout(10, break_when_any_done=False)
13955+
assert (rollout["time_reset"] > 0).sum() == 2
13956+
assert (rollout["time_policy"] == 0).sum() == 2
13957+
assert (rollout["time_step"] == 0).sum() == 2
13958+
assert (rollout["next", "time_reset"] == 0).all()
13959+
assert (rollout["next", "time_policy"] > 0).all()
13960+
assert (rollout["next", "time_step"] > 0).all()
1395313961

1395413962
def test_transform_model(self):
1395513963
torch.manual_seed(0)

torchrl/envs/transforms/transforms.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10837,12 +10837,16 @@ class Timer(Transform):
1083710837
1083810838
Attributes:
1083910839
out_keys: The keys of the output tensordict for the inverse transform. Defaults to
10840-
`out_keys = [f"{time_key}_step", f"{time_key}_policy"]`, where the first key represents
10840+
`out_keys = [f"{time_key}_step", f"{time_key}_policy", f"{time_key}_reset"]`, where the first key represents
1084110841
the time it takes to make a step in the environment, and the second key represents the
10842-
time it takes to execute the policy.
10842+
time it takes to execute the policy, and the third key represents the time taken by the call to `reset`.
1084310843
time_key: A prefix for the keys where the time intervals will be stored in the tensordict.
1084410844
Defaults to `"time"`.
1084510845
10846+
.. note:: During a succession of rollouts, the time marks of the reset are written at the root (the `"time_reset"`
10847+
entry or equivalent key is always 0 in the `"next"` tensordict). At the root, the `"time_policy"` and `"time_step"`
10848+
entries will be 0 when there is a reset. They will never be `0` in the `"next"` tensordict.
10849+
1084610850
Examples:
1084710851
>>> from torchrl.envs import Timer, GymEnv
1084810852
>>>
@@ -10854,20 +10858,23 @@ class Timer(Transform):
1085410858
>>> print("time for step", r["time_step"])
1085510859
time for step tensor([9.5797e-04, 1.6289e-03, 9.7990e-05, 8.0824e-05, 9.0837e-05, 7.6056e-05,
1085610860
8.2016e-05, 7.6056e-05, 8.1062e-05, 7.7009e-05])
10861+
10862+
1085710863
"""
1085810864

1085910865
def __init__(self, out_keys: Sequence[NestedKey] = None, time_key: str = "time"):
1086010866
if out_keys is None:
10861-
out_keys = [f"{time_key}_step", f"{time_key}_policy"]
10862-
elif len(out_keys) != 2:
10863-
raise TypeError(f"Expected two out_keys. Got out_keys={out_keys}.")
10867+
out_keys = [f"{time_key}_step", f"{time_key}_policy", f"{time_key}_reset"]
10868+
elif len(out_keys) != 3:
10869+
raise TypeError(f"Expected three out_keys. Got out_keys={out_keys}.")
1086410870
super().__init__([], out_keys)
1086510871
self.time_key = time_key
1086610872
self.last_inv_time = None
1086710873
self.last_call_time = None
10874+
self.last_reset_time = None
1086810875

1086910876
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
10870-
self.last_inv_time = time.time()
10877+
self.last_reset_time = self.last_inv_time = time.time()
1087110878
return tensordict
1087210879

1087310880
def _maybe_expand_and_set(self, key, time_elapsed, tensordict):
@@ -10888,11 +10895,14 @@ def _reset(
1088810895
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
1088910896
) -> TensorDictBase:
1089010897
current_time = time.time()
10891-
if self.last_inv_time is not None:
10898+
if self.last_reset_time is not None:
1089210899
time_elapsed = torch.tensor(
10893-
current_time - self.last_inv_time, device=tensordict.device
10900+
current_time - self.last_reset_time, device=tensordict.device
10901+
)
10902+
self._maybe_expand_and_set(self.out_keys[2], time_elapsed, tensordict_reset)
10903+
self._maybe_expand_and_set(
10904+
self.out_keys[0], time_elapsed * 0, tensordict_reset
1089410905
)
10895-
self._maybe_expand_and_set(self.out_keys[0], time_elapsed, tensordict_reset)
1089610906
self.last_call_time = current_time
1089710907
# Placeholder
1089810908
self._maybe_expand_and_set(self.out_keys[1], time_elapsed * 0, tensordict_reset)
@@ -10917,6 +10927,9 @@ def _step(
1091710927
current_time - self.last_inv_time, device=tensordict.device
1091810928
)
1091910929
self._maybe_expand_and_set(self.out_keys[0], time_elapsed, next_tensordict)
10930+
self._maybe_expand_and_set(
10931+
self.out_keys[2], time_elapsed * 0, next_tensordict
10932+
)
1092010933
self.last_call_time = current_time
1092110934
# presumably no need to worry about batch size incongruencies here
1092210935
next_tensordict.set(self.out_keys[1], tensordict.get(self.out_keys[1]))
@@ -10929,6 +10942,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
1092910942
observation_spec[self.out_keys[1]] = Unbounded(
1093010943
shape=observation_spec.shape, device=observation_spec.device
1093110944
)
10945+
observation_spec[self.out_keys[2]] = Unbounded(
10946+
shape=observation_spec.shape, device=observation_spec.device
10947+
)
1093210948
return observation_spec
1093310949

1093410950
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

0 commit comments

Comments
 (0)