@@ -10837,12 +10837,16 @@ class Timer(Transform):
10837
10837
10838
10838
Attributes:
10839
10839
out_keys: The keys of the output tensordict for the inverse transform. Defaults to
10840
- `out_keys = [f"{time_key}_step", f"{time_key}_policy"]`, where the first key represents
10840
+ `out_keys = [f"{time_key}_step", f"{time_key}_policy", f"{time_key}_reset"]`, where the first key represents
10841
10841
the time it takes to make a step in the environment, and the second key represents the
10842
- time it takes to execute the policy.
10842
+ time it takes to execute the policy; the third key represents the time taken by the call to `reset`.
10843
10843
time_key: A prefix for the keys where the time intervals will be stored in the tensordict.
10844
10844
Defaults to `"time"`.
10845
10845
10846
+ .. note:: During a succession of rollouts, the time marks of the reset are written at the root (the `"time_reset"`
10847
+ entry or equivalent key is always 0 in the `"next"` tensordict). At the root, the `"time_policy"` and `"time_step"`
10848
+ entries will be 0 when there is a reset. They will never be `0` in the `"next"` tensordict.
10849
+
10846
10850
Examples:
10847
10851
>>> from torchrl.envs import Timer, GymEnv
10848
10852
>>>
@@ -10854,20 +10858,23 @@ class Timer(Transform):
10854
10858
>>> print("time for step", r["time_step"])
10855
10859
time for step tensor([9.5797e-04, 1.6289e-03, 9.7990e-05, 8.0824e-05, 9.0837e-05, 7.6056e-05,
10856
10860
8.2016e-05, 7.6056e-05, 8.1062e-05, 7.7009e-05])
10861
+
10862
+
10857
10863
"""
10858
10864
10859
10865
def __init__ (self , out_keys : Sequence [NestedKey ] = None , time_key : str = "time" ):
10860
10866
if out_keys is None :
10861
- out_keys = [f"{ time_key } _step" , f"{ time_key } _policy" ]
10862
- elif len (out_keys ) != 2 :
10863
- raise TypeError (f"Expected two out_keys. Got out_keys={ out_keys } ." )
10867
+ out_keys = [f"{ time_key } _step" , f"{ time_key } _policy" , f" { time_key } _reset" ]
10868
+ elif len (out_keys ) != 3 :
10869
+ raise TypeError (f"Expected three out_keys. Got out_keys={ out_keys } ." )
10864
10870
super ().__init__ ([], out_keys )
10865
10871
self .time_key = time_key
10866
10872
self .last_inv_time = None
10867
10873
self .last_call_time = None
10874
+ self .last_reset_time = None
10868
10875
10869
10876
def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
10870
- self .last_inv_time = time .time ()
10877
+ self .last_reset_time = self . last_inv_time = time .time ()
10871
10878
return tensordict
10872
10879
10873
10880
def _maybe_expand_and_set (self , key , time_elapsed , tensordict ):
@@ -10888,11 +10895,14 @@ def _reset(
10888
10895
self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
10889
10896
) -> TensorDictBase :
10890
10897
current_time = time .time ()
10891
- if self .last_inv_time is not None :
10898
+ if self .last_reset_time is not None :
10892
10899
time_elapsed = torch .tensor (
10893
- current_time - self .last_inv_time , device = tensordict .device
10900
+ current_time - self .last_reset_time , device = tensordict .device
10901
+ )
10902
+ self ._maybe_expand_and_set (self .out_keys [2 ], time_elapsed , tensordict_reset )
10903
+ self ._maybe_expand_and_set (
10904
+ self .out_keys [0 ], time_elapsed * 0 , tensordict_reset
10894
10905
)
10895
- self ._maybe_expand_and_set (self .out_keys [0 ], time_elapsed , tensordict_reset )
10896
10906
self .last_call_time = current_time
10897
10907
# Placeholder
10898
10908
self ._maybe_expand_and_set (self .out_keys [1 ], time_elapsed * 0 , tensordict_reset )
@@ -10917,6 +10927,9 @@ def _step(
10917
10927
current_time - self .last_inv_time , device = tensordict .device
10918
10928
)
10919
10929
self ._maybe_expand_and_set (self .out_keys [0 ], time_elapsed , next_tensordict )
10930
+ self ._maybe_expand_and_set (
10931
+ self .out_keys [2 ], time_elapsed * 0 , next_tensordict
10932
+ )
10920
10933
self .last_call_time = current_time
10921
10934
# presumbly no need to worry about batch size incongruencies here
10922
10935
next_tensordict .set (self .out_keys [1 ], tensordict .get (self .out_keys [1 ]))
@@ -10929,6 +10942,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
10929
10942
observation_spec [self .out_keys [1 ]] = Unbounded (
10930
10943
shape = observation_spec .shape , device = observation_spec .device
10931
10944
)
10945
+ observation_spec [self .out_keys [2 ]] = Unbounded (
10946
+ shape = observation_spec .shape , device = observation_spec .device
10947
+ )
10932
10948
return observation_spec
10933
10949
10934
10950
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
0 commit comments