Skip to content

Commit 82a70b6

Browse files
committed
Pass mocap IDs to randomize factories. Fix wrong config flags
1 parent cac75f7 commit 82a70b6

File tree

7 files changed

+46
-42
lines changed

7 files changed

+46
-42
lines changed

benchmarks/config/test.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ attitude_freq = 500 # Controller frequency, in Hz.
66
gui = false # Enable/disable PyBullet's GUI
77

88
[env]
9-
random_resets = false # Whether to re-seed the random number generator between episodes
9+
random_resets = true # Whether to re-seed the random number generator between episodes
1010
seed = 1337 # Random seed
1111
freq = 50 # Frequency of the environment's step function, in Hz
1212
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.

config/level0.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ high = [0.1, 0.1, 0.1]
3838

3939
[env]
4040
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
41-
random_resets = true # Whether to re-seed the random number generator between episodes
41+
random_resets = false # Whether to re-seed the random number generator between episodes
4242
seed = 1337 # Random seed
4343
freq = 50 # Frequency of the environment's step function, in Hz
4444
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.

config/level1.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GU
2525

2626
[env]
2727
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
28-
random_resets = true # Whether to re-seed the random number generator between episodes
28+
random_resets = false # Whether to re-seed the random number generator between episodes
2929
seed = 1337 # Random seed
3030
freq = 50 # Frequency of the environment's step function, in Hz
3131
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.

config/level2.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GU
2525

2626
[env]
2727
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
28-
random_resets = true # Whether to re-seed the random number generator between episodes
28+
random_resets = false # Whether to re-seed the random number generator between episodes
2929
seed = 1337 # Random seed
3030
freq = 50 # Frequency of the environment's step function, in Hz
3131
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.

config/level3.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GU
2525

2626
[env]
2727
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
28-
random_resets = false # Whether to re-seed the random number generator between episodes
28+
random_resets = true # Whether to re-seed the random number generator between episodes
2929
seed = 1337 # Random seed
3030
freq = 50 # Frequency of the environment's step function, in Hz
3131
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.

lsy_drone_racing/envs/drone_racing_env.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@
3939
from gymnasium import spaces
4040
from scipy.spatial.transform import Rotation as R
4141

42-
from lsy_drone_racing.envs.randomize import randomize_sim_fn
42+
from lsy_drone_racing.envs.randomize import (
43+
randomize_drone_inertia_fn,
44+
randomize_drone_mass_fn,
45+
randomize_drone_pos_fn,
46+
randomize_drone_quat_fn,
47+
randomize_gate_pos_fn,
48+
randomize_gate_rpy_fn,
49+
randomize_obstacle_pos_fn,
50+
)
4351
from lsy_drone_racing.utils import check_gate_pass
4452

4553
if TYPE_CHECKING:
@@ -147,7 +155,7 @@ def __init__(self, config: dict):
147155
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
148156
self.n_gates = len(config.env.track.gates)
149157
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
150-
self.randomization = self.load_randomizations(config.env.get("randomization", None))
158+
self.randomizations = self.load_randomizations(config.env.get("randomization", None))
151159
self.contact_mask = np.ones((self.sim.n_worlds, 25), dtype=bool)
152160
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor
153161

@@ -250,7 +258,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
250258
obstacles_pos[self.obstacles_visited] = self.obstacles["pos"][self.obstacles_visited]
251259
obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
252260
obs["obstacles_visited"] = self.obstacles_visited
253-
# TODO: Observation disturbances?
261+
# TODO: Decide on observation disturbances
254262
return obs
255263

256264
def reward(self) -> float:
@@ -347,7 +355,9 @@ def setup_sim(self):
347355
rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
348356
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
349357
self.sim.data = self.sim.data.replace(states=states)
350-
self.sim.reset_hook = build_reset_hook(self.randomization)
358+
self.sim.reset_hook = build_reset_hook(
359+
self.randomizations, self.gates["mocap_ids"], self.obstacles["mocap_ids"]
360+
)
351361
if "dynamics" in self.disturbances:
352362
self.sim.disturbance_fn = build_dynamics_disturbance_fn(self.disturbances["dynamics"])
353363
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
@@ -400,13 +410,33 @@ def close(self):
400410
self.sim.close()
401411

402412

403-
def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array], SimData]:
413+
def build_reset_hook(
414+
randomizations: dict, gate_mocap_ids: list[int], obstacle_mocap_ids: list[int]
415+
) -> Callable[[SimData, Array], SimData]:
404416
"""Build the reset hook for the simulation."""
405-
randomizations = [randomize_sim_fn(target, rng) for target, rng in randomizations.items()]
417+
randomization_fns = []
418+
for target, rng in randomizations.items():
419+
match target:
420+
case "drone_pos":
421+
randomization_fns.append(randomize_drone_pos_fn(rng))
422+
case "drone_rpy":
423+
randomization_fns.append(randomize_drone_quat_fn(rng))
424+
case "drone_mass":
425+
randomization_fns.append(randomize_drone_mass_fn(rng))
426+
case "drone_inertia":
427+
randomization_fns.append(randomize_drone_inertia_fn(rng))
428+
case "gate_pos":
429+
randomization_fns.append(randomize_gate_pos_fn(rng, gate_mocap_ids))
430+
case "gate_rpy":
431+
randomization_fns.append(randomize_gate_rpy_fn(rng, gate_mocap_ids))
432+
case "obstacle_pos":
433+
randomization_fns.append(randomize_obstacle_pos_fn(rng, obstacle_mocap_ids))
434+
case _:
435+
raise ValueError(f"Invalid target: {target}")
406436

407437
def reset_hook(data: SimData, mask: Array) -> SimData:
408-
for randomize in randomizations:
409-
data = randomize(data, mask)
438+
for randomize_fn in randomization_fns:
439+
data = randomize_fn(data, mask)
410440
return data
411441

412442
return reset_hook

lsy_drone_racing/envs/randomize.py

+3-29
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,6 @@
1515
from jax.scipy.spatial.transform import Rotation as R
1616

1717

18-
def randomize_sim_fn(
19-
target: str, randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array]
20-
) -> Callable[[SimData, Array], SimData]:
21-
"""Create a function that randomizes aspects of the simulation."""
22-
match target:
23-
case "drone_pos":
24-
return randomize_drone_pos_fn(randomize_fn)
25-
case "drone_rpy":
26-
return randomize_drone_quat_fn(randomize_fn)
27-
case "drone_mass":
28-
return randomize_drone_mass_fn(randomize_fn)
29-
case "drone_inertia":
30-
return randomize_drone_inertia_fn(randomize_fn)
31-
case "gate_pos":
32-
return randomize_gate_pos_fn(randomize_fn)
33-
case "gate_rpy":
34-
return randomize_gate_rpy_fn(randomize_fn)
35-
case "obstacle_pos":
36-
return randomize_obstacle_pos_fn(randomize_fn)
37-
case _:
38-
raise ValueError(f"Invalid target: {target}")
39-
40-
4118
def randomize_drone_pos_fn(
4219
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
4320
) -> Callable[[SimData, Array], SimData]:
@@ -97,10 +74,9 @@ def randomize_drone_inertia(data: SimData, mask: Array) -> SimData:
9774

9875

9976
def randomize_gate_pos_fn(
100-
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
77+
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], gate_ids: list[int]
10178
) -> Callable[[SimData, Array], SimData]:
10279
"""Create a function that randomizes the gate position."""
103-
gate_ids = [0, 1, 2, 3] # TODO: Make this dynamic
10480

10581
def randomize_gate_pos(data: SimData, mask: Array) -> SimData:
10682
key, subkey = jax.random.split(data.core.rng_key)
@@ -114,10 +90,9 @@ def randomize_gate_pos(data: SimData, mask: Array) -> SimData:
11490

11591

11692
def randomize_gate_rpy_fn(
117-
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
93+
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], gate_ids: list[int]
11894
) -> Callable[[SimData, Array], SimData]:
11995
"""Create a function that randomizes the gate rotation."""
120-
gate_ids = [0, 1, 2, 3] # TODO: Make this dynamic
12196

12297
def randomize_gate_rpy(data: SimData, mask: Array) -> SimData:
12398
key, subkey = jax.random.split(data.core.rng_key)
@@ -133,10 +108,9 @@ def randomize_gate_rpy(data: SimData, mask: Array) -> SimData:
133108

134109

135110
def randomize_obstacle_pos_fn(
136-
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
111+
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], obstacle_ids: list[int]
137112
) -> Callable[[SimData, Array], SimData]:
138113
"""Create a function that randomizes the obstacle position."""
139-
obstacle_ids = [4, 5, 6, 7] # TODO: Make this dynamic
140114

141115
def randomize_obstacle_pos(data: SimData, mask: Array) -> SimData:
142116
key, subkey = jax.random.split(data.core.rng_key)

0 commit comments

Comments
 (0)