Skip to content

Commit

Permalink
[BugFix] Fix device mismatch bug on SimplerEnv envs when using physx_…
Browse files Browse the repository at this point in the history
…cpu sim backend and sapien_cuda render backend and fix a test
  • Loading branch information
StoneT2000 committed Feb 25, 2025
1 parent 7b71ced commit 29d55aa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
19 changes: 17 additions & 2 deletions mani_skill/envs/tasks/digital_twins/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,14 @@ def _after_reconfigure(self, options: dict):
def _green_sceen_rgb(self, rgb, segmentation, overlay_img):
"""returns green screened RGB data given a batch of RGB and segmentation images and one overlay image"""
actor_seg = segmentation[..., 0]
mask = torch.ones_like(actor_seg)
mask = torch.ones_like(actor_seg, device=actor_seg.device)
if actor_seg.device != self.robot_link_ids.device:
# if using CPU simulation, the device of the robot_link_ids and target_object_actor_ids will be CPU first
# but for most users who use the sapien_cuda render backend image data will be on the GPU.
self.robot_link_ids = self.robot_link_ids.to(actor_seg.device)
self.target_object_actor_ids = self.target_object_actor_ids.to(
actor_seg.device
)
if ("background" in self.rgb_overlay_mode) or (
"debug" in self.rgb_overlay_mode
):
Expand Down Expand Up @@ -151,10 +158,18 @@ def get_obs(self, info: dict = None):
assert (
"segmentation" in obs["sensor_data"][camera_name].keys()
), "Image overlay requires segment info in the observation!"
if (
self._rgb_overlay_images[camera_name].device
!= obs["sensor_data"][camera_name]["rgb"].device
):
self._rgb_overlay_images[camera_name] = self._rgb_overlay_images[
camera_name
].to(obs["sensor_data"][camera_name]["rgb"].device)
overlay_img = self._rgb_overlay_images[camera_name]
green_screened_rgb = self._green_sceen_rgb(
obs["sensor_data"][camera_name]["rgb"],
obs["sensor_data"][camera_name]["segmentation"],
self._rgb_overlay_images[camera_name],
overlay_img,
)
obs["sensor_data"][camera_name]["rgb"] = green_screened_rgb
return obs
3 changes: 2 additions & 1 deletion tests/test_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from stable_baselines3.common.vec_env import SubprocVecEnv

import mani_skill.envs
from mani_skill.utils.wrappers.gymnasium import CPUGymWrapper
from tests.utils import VENV_OBS_MODES


Expand All @@ -15,7 +16,7 @@ def test_gymnasium_cpu_vecenv(env_id, obs_mode):
env_id,
n_envs,
obs_mode=obs_mode,
# wrappers=[FlattenObservationWrapper],
wrappers=[CPUGymWrapper],
vectorization_mode="sync",
)
np.random.seed(2022)
Expand Down

0 comments on commit 29d55aa

Please sign in to comment.