Skip to content

Commit 20cbbce

Browse files
committed
[wip,broken] Accelerate core env
1 parent 1c845e8 commit 20cbbce

File tree

4 files changed

+391
-1004
lines changed

4 files changed

+391
-1004
lines changed

benchmarks/sim.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@
7373

7474
multi_drone_env_setup_code = """
7575
import gymnasium
76+
import jax
7677
7778
import lsy_drone_racing
7879
7980
env = gymnasium.make('MultiDroneRacing-v0',
80-
n_envs=1000, # TODO: Remove this for single-world envs
81+
n_envs=1, # TODO: Remove this for single-world envs
8182
n_drones=config.env.n_drones,
8283
freq=config.env.freq,
8384
sim_config=config.sim,
@@ -87,11 +88,18 @@
8788
randomizations=config.env.get("randomizations"),
8889
random_resets=config.env.random_resets,
8990
seed=config.env.seed,
90-
device='gpu',
91+
device='cpu',
9192
)
93+
9294
env.reset()
93-
env.step(env.action_space.sample()) # JIT compile
94-
env.reset()
95+
# JIT step
96+
env.step(env.action_space.sample())
97+
jax.block_until_ready(env.unwrapped.data)
98+
# JIT masked reset (used in autoreset)
99+
mask = env.unwrapped.data.marked_for_reset
100+
mask = mask.at[0].set(True)
101+
env.unwrapped.reset(mask=mask)
102+
jax.block_until_ready(env.unwrapped.data)
95103
env.action_space.seed(2)
96104
"""
97105

0 commit comments

Comments
 (0)