diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index bea46bb6cd4..a4f9594f389 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -575,7 +575,7 @@ def __init__( reset_when_done: bool = True, interruptor=None, ): - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True @@ -589,7 +589,7 @@ def __init__( else: env = create_env_fn if create_env_kwargs: - if not isinstance(env, _BatchedEnv): + if not isinstance(env, BatchedEnvBase): raise RuntimeError( "kwargs were passed to SyncDataCollector but they can't be set " f"on environment of type {type(create_env_fn)}." @@ -1191,11 +1191,11 @@ def state_dict(self) -> OrderedDict: `"env_state_dict"`. """ - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase if isinstance(self.env, TransformedEnv): env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, _BatchedEnv): + elif isinstance(self.env, BatchedEnvBase): env_state_dict = self.env.state_dict() else: env_state_dict = OrderedDict() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e137dcd441f..29a835419f1 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -48,7 +48,7 @@ def _check_start(fun): - def decorated_fun(self: _BatchedEnv, *args, **kwargs): + def decorated_fun(self: BatchedEnvBase, *args, **kwargs): if self.is_closed: self._create_td() self._start_workers() @@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs): return super().__call__(*args, **kwargs) -class _BatchedEnv(EnvBase): +class BatchedEnvBase(EnvBase): """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. 
Those queries will return a list of length equal to the number of workers containing the @@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase): serial_for_single (bool, optional): if ``True``, creating a parallel environment with a single worker will return a :class:`~SerialEnv` instead. This option has no effect with :class:`~SerialEnv`. Defaults to ``False``. + non_blocking (bool, optional): if ``True``, device moves will be done using the + ``non_blocking=True`` option. Defaults to ``True`` for batched environments + on cuda devices, and ``False`` otherwise. Examples: >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator @@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase): >>> env = ParallelEnv(2, [ ... lambda: DMControlEnv("humanoid", "stand"), ... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands - >>> r = env.rollout(10) # executes 10 random steps in the environment - >>> r[0] # data for Humanoid stand + >>> rollout = env.rollout(10) # executes 10 random steps in the environment + >>> rollout[0] # data for Humanoid stand TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) - >>> r[1] # data for Humanoid walk + >>> rollout[1] # data for Humanoid walk TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) + >>> # serial_for_single to avoid creating parallel envs if not necessary >>> env = ParallelEnv(1, make_env, serial_for_single=True) >>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary """ @@ -270,6 +274,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, 
serial_for_single: bool = False, + non_blocking: bool = None, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -327,6 +332,15 @@ def __init__( # self._prepare_dummy_env(create_env_fn, create_env_kwargs) self._properties_set = False self._get_metadata(create_env_fn, create_env_kwargs) + self._non_blocking = non_blocking + + @property + def non_blocking(self): + nb = self._non_blocking + if nb is None: + nb = self.device is not None and self.device.type == "cuda" + self._non_blocking = nb + return nb def _get_metadata( self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] @@ -658,6 +672,7 @@ def start(self) -> None: self._start_workers() def to(self, device: DEVICE_TYPING): + self._non_blocking = None device = torch.device(device) if device == self.device: return self @@ -679,10 +694,10 @@ def to(self, device: DEVICE_TYPING): return self -class SerialEnv(_BatchedEnv): +class SerialEnv(BatchedEnvBase): """Creates a series of environments in the same process.""" - __doc__ += _BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ _share_memory = False @@ -773,7 +788,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: else: env_device = _env.device if env_device != self.device and env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=True) + tensordict_ = tensordict_.to( + env_device, non_blocking=self.non_blocking + ) else: tensordict_ = tensordict_.clone(False) else: @@ -807,7 +824,7 @@ def select_and_clone(name, tensor): if device is None: out = out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -828,7 +845,9 @@ def _step( # There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in = tensordict_in[i].to(env_device, non_blocking=True) + data_in = tensordict_in[i].to( + env_device, non_blocking=self.non_blocking + ) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) @@ -849,7 +868,7 @@ def select_and_clone(name, tensor): if device is None: out = out.clear_device_() elif out.device != device: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out def __getattr__(self, attr: str) -> Any: @@ -895,14 +914,14 @@ def to(self, device: DEVICE_TYPING): return self -class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): +class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """Creates one environment per process. TensorDicts are passed via shared memory or memory map. """ - __doc__ += _BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ __doc__ += """ .. warning:: @@ -1178,14 +1197,14 @@ def step_and_maybe_reset( tensordict_ = tensordict_.clone() elif device is not None: next_td = next_td._fast_apply( - lambda x: x.to(device, non_blocking=True) + lambda x: x.to(device, non_blocking=self.non_blocking) if x.device != device else x.clone(), device=device, # filter_empty=True, ) tensordict_ = tensordict_._fast_apply( - lambda x: x.to(device, non_blocking=True) + lambda x: x.to(device, non_blocking=self.non_blocking) if x.device != device else x.clone(), device=device, @@ -1250,7 +1269,7 @@ def select_and_clone(name, tensor): if device is None: out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out @_check_start @@ -1337,7 +1356,7 @@ def select_and_clone(name, tensor): if device is None: out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out @_check_start @@ -1657,12 +1676,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): 
child_pipe.send(("_".join([cmd, "done"]), None)) -def _update_cuda(t_dest, t_source): - if t_source is None: - return - t_dest.copy_(t_source.pin_memory(), non_blocking=True) - return - - def _filter_empty(tensordict): return tensordict.select(*tensordict.keys(True, True)) + + +# Create an alias for possible imports +_BatchedEnv = BatchedEnvBase