I'm using `torchrl.collectors.MultiSyncDataCollector` with the following setup:

- `reset_at_each_iter=True`
- `cat_results=-1`
- `frames_per_batch=N`
My goal is to filter each collected batch to retain only complete trajectories, i.e., episodes that start with `from_step_count == 0` and end with `done == True` within the same batch.
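A minimal sketch of that filter in plain torch, under these assumptions: the batch is flat with shape `[frames_per_batch]`, each worker's frames are contiguous and in time order, the per-frame counter is called `step_count` here (substitute the custom `from_step_count`), and the end-of-transition flag is read from `("next", "done")`, which is where TorchRL rollouts store it. `complete_trajectory_mask` is a hypothetical helper, not a TorchRL function. TorchRL does ship `torchrl.collectors.utils.split_trajectories`, which pads a batch into one row per trajectory plus a validity mask; it may be the more idiomatic route, though its keys and arguments may differ between versions. If your collector output contains `("collector", "traj_ids")`, those ids could replace the `cumsum` below.

```python
import torch


def complete_trajectory_mask(step_count: torch.Tensor, done: torch.Tensor) -> torch.Tensor:
    """Bool mask over a flat [T] batch keeping only frames of trajectories that
    both start (step_count == 0) and end (done) inside the batch."""
    step_count = step_count.reshape(-1)  # TorchRL often stores these as [T, 1]
    done = done.reshape(-1).bool()
    start = step_count == 0
    # Trajectory id of each frame = index of the most recent start frame;
    # frames before the first start get -1 (continuation of an older episode).
    traj_id = torch.cumsum(start.long(), dim=0) - 1
    # Ids of trajectories that contain a done frame.
    finished = torch.unique(traj_id[done & (traj_id >= 0)])
    return (traj_id >= 0) & torch.isin(traj_id, finished)


# Two workers concatenated: each contributes one finished and one truncated trajectory.
step_count = torch.tensor([0, 1, 2, 0, 1, 0, 1, 0, 1, 2])
done = torch.tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.bool)
mask = complete_trajectory_mask(step_count, done)
print(mask.tolist())
```

On the real batch this would be applied with boolean indexing over the batch dimension, e.g. `td = td[complete_trajectory_mask(td["from_step_count"], td["next", "done"])]`, assuming `from_step_count` sits at the root of the TensorDict. Because trajectory ids are derived from reset frames rather than from positions, truncated tails in the middle of the batch (at worker boundaries) are dropped as well as the one at the end.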
The motivation behind this is that I calculate the reward at the end of each trajectory and then propagate it back to the previous steps. Within batch processing, this only works if the end of the trajectory is also included in the batch.
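The backward propagation described here can be sketched as a reverse scan over one time-ordered stream. The actual redistribution rule is not specified in this post, so as an assumption for illustration the terminal reward is discounted back by a hypothetical factor `gamma`, and intermediate rewards are ignored. Frames with no `done` later in the batch come out as NaN, which makes the problem with incomplete trajectories visible.

```python
import torch


def propagate_terminal_reward(reward: torch.Tensor, done: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """Hypothetical scheme: each frame receives gamma**k * R_end, where R_end is the
    reward at the end of its trajectory and k the number of steps until that end."""
    reward = reward.reshape(-1).float()
    done = done.reshape(-1).bool()
    out = torch.full_like(reward, float("nan"))
    running = float("nan")  # no trajectory end seen yet while scanning backwards
    for t in range(reward.shape[0] - 1, -1, -1):
        if done[t]:
            running = reward[t].item()  # start of the backward pass for this trajectory
        else:
            running = running * gamma  # NaN stays NaN for the unfinished tail
        out[t] = running
    return out


reward = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0])
done = torch.tensor([False, False, True, False, False, True, False])
shaped = propagate_terminal_reward(reward, done, gamma=0.5)
```

The last frame belongs to a trajectory that does not end inside the batch, so it has no value to propagate; filtering such trajectories out beforehand avoids having to handle these frames at all.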
Given `reset_at_each_iter=True`, I assume that each batch starts with freshly reset environments and therefore no trajectories from previous batches are continued. However, some trajectories may not terminate within the current batch, and I'd like to exclude those before passing the data to GAE.
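The `reset_at_each_iter` assumption can be sanity-checked on a collected batch. A minimal sketch, assuming a flat, time-ordered batch with a per-frame `step_count` (substitute `from_step_count`) and the done flag taken from `("next", "done")`: the first frame should be a reset, and every frame directly following a `done` frame should be a reset too. `starts_fresh` is a hypothetical helper name.

```python
import torch


def starts_fresh(step_count: torch.Tensor, done: torch.Tensor) -> bool:
    """True if the batch begins with a reset and every frame after a done frame is a reset."""
    step_count = step_count.reshape(-1)
    done = done.reshape(-1).bool()
    first_is_reset = bool(step_count[0] == 0)
    # Step counts of the frames that immediately follow a done frame.
    after_done = step_count[1:][done[:-1]]
    return first_is_reset and bool((after_done == 0).all())


ok = starts_fresh(torch.tensor([0, 1, 2, 0, 1]), torch.tensor([False, False, True, False, False]))
continued = starts_fresh(torch.tensor([5, 6, 0]), torch.tensor([False, True, False]))
```

This check cannot see worker boundaries inside a `cat_results=-1` batch: a truncated tail from one worker is followed directly by the next worker's reset frame, so incomplete trajectories can sit in the middle of the batch, not only at its end.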
My questions:

1. What is the best way to filter a TensorDict of shape `[frames_per_batch]` to retain only full episodes? I already track `"done"` and `"step_count"` (a custom `"from_step_count"`) per frame. I'm currently considering filtering for trajectories that contain `done=True` and start at `step == 0`, but I'm not sure what the cleanest or most idiomatic way to achieve that in TorchRL is.
2. Is there a recommended utility in TorchRL to extract complete trajectories from such a batched TensorDict, or is a manual implementation using episode indices and per-trajectory slicing preferred?
3. Are there any best practices for combining this with GAE using `shifted=True` in this setup?
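On the GAE side, a plain-torch reference of the standard GAE recursion (not TorchRL's implementation) illustrates why a flat batch of concatenated complete trajectories is unproblematic in principle: the `(1 - done)` factor stops advantages from leaking across trajectory boundaries, and the `(1 - terminated)` factor drops the bootstrap value at a true episode end. According to the TorchRL docs, `shifted=True` estimates the value and next value with a single call to the value network; how that interacts with trajectory boundaries inside a flat batch is worth verifying against the docs or source for your version, and this sketch does not model it. `gae_reference` is a hypothetical name.

```python
import torch


def gae_reference(reward, value, next_value, done, terminated, gamma=0.99, lmbda=0.95):
    """Standard GAE over a flat [T] batch of concatenated trajectories.

    done       -> cuts the recursion at trajectory boundaries
    terminated -> removes the bootstrap term V(s_{t+1}) at a true episode end
    """
    reward, value, next_value = (x.reshape(-1).float() for x in (reward, value, next_value))
    not_done = (~done.reshape(-1).bool()).float()
    not_terminated = (~terminated.reshape(-1).bool()).float()
    # One-step TD errors.
    delta = reward + gamma * not_terminated * next_value - value
    advantage = torch.zeros_like(reward)
    running = torch.zeros(())
    for t in range(reward.shape[0] - 1, -1, -1):
        running = delta[t] + gamma * lmbda * not_done[t] * running
        advantage[t] = running
    return advantage, advantage + value  # (advantage, value target)


# Two trajectories concatenated: frames 0-1 and frame 2.
done = torch.tensor([False, True, True])
adv, target = gae_reference(
    torch.ones(3), torch.zeros(3), torch.zeros(3), done, done, gamma=1.0, lmbda=1.0
)
```

With `gamma = lmbda = 1` the first trajectory accumulates an advantage of 2 at its first frame, while frame 1 stays at 1 because the recursion is cut at the `done` boundary and does not pick up frame 2.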
Thanks in advance for any guidance or example patterns.
I’d be happy to share my current code if that helps.
Best regards,
Markus