|
19 | 19 |
|
20 | 20 | env_setup_code = """
|
21 | 21 | import gymnasium
|
| 22 | +import jax |
22 | 23 |
|
23 | 24 | import lsy_drone_racing
|
24 | 25 |
|
25 |
| -env = gymnasium.make( |
| 26 | +env = gymnasium.make_vec( |
26 | 27 | config.env.id,
|
| 28 | + num_envs={num_envs}, |
27 | 29 | freq=config.env.freq,
|
28 | 30 | sim_config=config.sim,
|
29 | 31 | sensor_range=config.env.sensor_range,
|
|
32 | 34 | randomizations=config.env.get("randomizations"),
|
33 | 35 | random_resets=config.env.random_resets,
|
34 | 36 | seed=config.env.seed,
|
| 37 | + device='{device}', |
35 | 38 | )
|
| 39 | +
|
| 40 | +# JIT compile the reset and step functions |
36 | 41 | env.reset()
|
37 |
| -env.step(env.action_space.sample()) # JIT compile |
38 |
| -env.reset() |
39 |
| -env.action_space.seed(42) |
40 |
| -action = env.action_space.sample() |
| 42 | +env.step(env.action_space.sample()) |
| 43 | +jax.block_until_ready(env.unwrapped.data) |
| 44 | +# JIT masked reset (used in autoreset) |
| 45 | +mask = env.unwrapped.data.marked_for_reset |
| 46 | +mask = mask.at[0].set(True) |
| 47 | +env.unwrapped._reset(mask=mask) # enforce masked reset compile |
| 48 | +jax.block_until_ready(env.unwrapped.data) |
| 49 | +env.action_space.seed(2) |
41 | 50 | """
|
42 | 51 |
|
43 | 52 | attitude_env_setup_code = """
|
44 | 53 | import gymnasium
|
| 54 | +import jax |
45 | 55 |
|
46 | 56 | import lsy_drone_racing
|
47 | 57 |
|
48 |
| -env = gymnasium.make('DroneRacingAttitude-v0', |
49 |
| - config.env.id, |
| 58 | +env = gymnasium.make_vec('DroneRacingAttitude-v0', |
| 59 | + num_envs={num_envs}, |
50 | 60 | freq=config.env.freq,
|
51 | 61 | sim_config=config.sim,
|
52 | 62 | sensor_range=config.env.sensor_range,
|
|
55 | 65 | randomizations=config.env.get("randomizations"),
|
56 | 66 | random_resets=config.env.random_resets,
|
57 | 67 | seed=config.env.seed,
|
| 68 | + device='{device}', |
58 | 69 | )
|
| 70 | +
|
| 71 | +# JIT compile the reset and step functions |
59 | 72 | env.reset()
|
60 |
| -env.step(env.action_space.sample()) # JIT compile |
61 |
| -env.reset() |
62 |
| -env.action_space.seed(42) |
63 |
| -action = env.action_space.sample() |
| 73 | +env.step(env.action_space.sample()) |
| 74 | +jax.block_until_ready(env.unwrapped.data) |
| 75 | +# JIT masked reset (used in autoreset) |
| 76 | +mask = env.unwrapped.data.marked_for_reset |
| 77 | +mask = mask.at[0].set(True) |
| 78 | +env.unwrapped._reset(mask=mask) # enforce masked reset compile |
| 79 | +jax.block_until_ready(env.unwrapped.data) |
| 80 | +env.action_space.seed(2) |
64 | 81 | """
|
65 | 82 |
|
66 | 83 | load_multi_drone_config_code = f"""
|
|
76 | 93 | import jax
|
77 | 94 |
|
78 | 95 | import lsy_drone_racing
|
79 |
| -from lsy_drone_racing.envs.multi_drone_race import MultiDroneRaceEnv |
| 96 | +from lsy_drone_racing.envs.multi_drone_race import VecMultiDroneRaceEnv |
80 | 97 |
|
81 |
| -env = gymnasium.make('MultiDroneRacing-v0', |
| 98 | +
|
| 99 | +env = gymnasium.make_vec('MultiDroneRacing-v0', |
| 100 | + num_envs={num_envs}, |
82 | 101 | n_drones=config.env.n_drones,
|
83 | 102 | freq=config.env.freq,
|
84 | 103 | sim_config=config.sim,
|
|
88 | 107 | randomizations=config.env.get("randomizations"),
|
89 | 108 | random_resets=config.env.random_resets,
|
90 | 109 | seed=config.env.seed,
|
91 |
| - device='cpu', |
| 110 | + device='{device}', |
92 | 111 | )
|
93 | 112 |
|
| 113 | +# JIT compile the reset and step functions |
94 | 114 | env.reset()
|
95 |
| -# JIT step |
96 | 115 | env.step(env.action_space.sample())
|
97 | 116 | jax.block_until_ready(env.unwrapped.data)
|
98 | 117 | # JIT masked reset (used in autoreset)
|
99 | 118 | mask = env.unwrapped.data.marked_for_reset
|
100 | 119 | mask = mask.at[0].set(True)
|
101 |
| -super(MultiDroneRaceEnv, env.unwrapped).reset(mask=mask) # enforce masked reset compile |
| 120 | +env.unwrapped._reset(mask=mask) # enforce masked reset compile |
102 | 121 | jax.block_until_ready(env.unwrapped.data)
|
103 | 122 | env.action_space.seed(2)
|
104 | 123 | """
|
105 | 124 |
|
106 | 125 |
|
107 |
| -def time_sim_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]: |
108 |
| - setup = load_config_code + env_setup_code |
| 126 | +def time_sim_reset( |
| 127 | + n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu" |
| 128 | +) -> NDArray[np.floating]: |
| 129 | + setup = load_config_code + env_setup_code.format(num_envs=n_envs, device=device) |
109 | 130 | stmt = """env.reset()"""
|
110 | 131 | return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
|
111 | 132 |
|
112 | 133 |
|
113 | 134 | def time_sim_step(
|
114 |
| - n_tests: int = 10, number: int = 1, physics_mode: str = "analytical" |
| 135 | + n_tests: int = 10, |
| 136 | + number: int = 1, |
| 137 | + physics_mode: str = "analytical", |
| 138 | + n_envs: int = 1, |
| 139 | + device: str = "cpu", |
115 | 140 | ) -> NDArray[np.floating]:
|
116 | 141 | modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
|
117 |
| - setup = load_config_code + modify_config_code + env_setup_code + "\nenv.reset()" |
118 |
| - stmt = """env.step(action)""" |
| 142 | + _env_setup_code = env_setup_code.format(num_envs=n_envs, device=device) |
| 143 | + setup = load_config_code + modify_config_code + _env_setup_code + "\nenv.reset()" |
| 144 | + stmt = """env.step(env.action_space.sample())""" |
119 | 145 | return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
|
120 | 146 |
|
121 | 147 |
|
122 |
| -def time_sim_attitude_step(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]: |
123 |
| - setup = load_config_code + attitude_env_setup_code + "\nenv.reset()" |
124 |
| - stmt = """env.step(action)""" |
| 148 | +def time_sim_attitude_step( |
| 149 | + n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu" |
| 150 | +) -> NDArray[np.floating]: |
| 151 | + env_setup_code = attitude_env_setup_code.format(num_envs=n_envs, device=device) |
| 152 | + setup = load_config_code + env_setup_code + "\nenv.reset()" |
| 153 | + stmt = """env.step(env.action_space.sample())""" |
125 | 154 | return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
|
126 | 155 |
|
127 | 156 |
|
128 |
| -def time_multi_drone_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]: |
129 |
| - setup = load_multi_drone_config_code + multi_drone_env_setup_code + "\nenv.reset()" |
| 157 | +def time_multi_drone_reset( |
| 158 | + n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu" |
| 159 | +) -> NDArray[np.floating]: |
| 160 | + env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device) |
| 161 | + setup = load_multi_drone_config_code + env_setup_code + "\nenv.reset()" |
130 | 162 | stmt = """env.reset()"""
|
131 | 163 | return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
|
132 | 164 |
|
133 | 165 |
|
134 | 166 | def time_multi_drone_step(
|
135 |
| - n_tests: int = 10, number: int = 100, physics_mode: str = "analytical" |
| 167 | + n_tests: int = 10, |
| 168 | + number: int = 100, |
| 169 | + physics_mode: str = "analytical", |
| 170 | + n_envs: int = 1, |
| 171 | + device: str = "cpu", |
136 | 172 | ) -> NDArray[np.floating]:
|
137 | 173 | modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
|
138 |
| - setup = ( |
139 |
| - load_multi_drone_config_code |
140 |
| - + modify_config_code |
141 |
| - + multi_drone_env_setup_code |
142 |
| - + "\nenv.reset()" |
143 |
| - ) |
| 174 | + env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device) |
| 175 | + |
| 176 | + setup = load_multi_drone_config_code + modify_config_code + env_setup_code + "\nenv.reset()" |
144 | 177 | stmt = """env.step(env.action_space.sample())"""
|
145 | 178 | return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
|
0 commit comments