
Commit 5d1a734

improve env interface; add landing env
1 parent 1bb1b94 commit 5d1a734

File tree

5 files changed (+61, −15 lines)


benchmark/main.py

+1 −1

@@ -44,7 +44,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic

     envs = gymnasium.make_vec(
         "DroneReachPos-v0",
-        max_episode_steps=200,
+        time_horizon_in_seconds=2,
         return_datatype="numpy",
         num_envs=sim_config.n_worlds,
        jax_random_key=42,
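The swapped values keep the benchmark's horizon unchanged in wall-clock terms, assuming the environment steps its controller at 100 Hz (an assumption; the commit only shows the two values):

# Hypothetical sanity check of the old vs. new horizon parameters.
# CONTROL_FREQ_HZ = 100 is assumed, not stated anywhere in this commit.
CONTROL_FREQ_HZ = 100
max_episode_steps = 200          # old interface
time_horizon_in_seconds = 2      # new interface
assert max_episode_steps / CONTROL_FREQ_HZ == time_horizon_in_seconds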

benchmark/performance.py

+1 −1

@@ -49,7 +49,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic

     envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
         "DroneReachPos-v0",
-        max_episode_steps=200,
+        time_horizon_in_seconds=2,
         return_datatype="numpy",
         num_envs=sim_config.n_worlds,
        jax_random_key=42,

crazyflow/gymnasium_envs/__init__.py

+8 −2

@@ -1,8 +1,8 @@
 from gymnasium.envs.registration import register

-from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity
+from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity, CrazyflowEnvLanding

-__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity"]
+__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity", "CrazyflowEnvLanding"]

 register(
     id="DroneReachPos-v0",
@@ -13,3 +13,9 @@
     id="DroneReachVel-v0",
     vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
 )
+
+
+register(
+    id="DroneLanding-v0",
+    vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvLanding",
+)
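Once registered, the landing task is constructed like the other drone envs; a minimal sketch, assuming crazyflow is installed and that any additional Sim kwargs (e.g. device, n_drones) are supplied as in the benchmark scripts:

import gymnasium
import crazyflow.gymnasium_envs  # noqa: F401 -- importing runs the register() calls above

# Hypothetical construction of the new env; the kwargs mirror the calls in
# benchmark/main.py and examples/gymnasium_env.py from this commit.
envs = gymnasium.make_vec(
    "DroneLanding-v0",
    time_horizon_in_seconds=2,
    return_datatype="numpy",
    num_envs=4,
    jax_random_key=42,
)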

crazyflow/gymnasium_envs/crazyflow.py

+49 −9

@@ -49,7 +49,7 @@ def __init__(
         *,
         jax_random_key: int,  # required for jax random number generator
         num_envs: int = 1,  # required for VectorEnv
-        max_episode_steps: int = 1000,
+        time_horizon_in_seconds: int = 10,
         return_datatype: Literal["numpy", "jax"] = "jax",
         **kwargs: dict,
     ):
@@ -58,7 +58,7 @@ def __init__(
         Args:
             jax_random_key: The random key for the jax random number generator.
             num_envs: The number of environments to run in parallel.
-            max_episode_steps: The maximum number of steps per episode.
+            time_horizon_in_seconds: The time horizon after which episodes are truncated.
             return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
                 the returned arrays will be numpy arrays on the CPU. If "jax", the returned arrays
                 will be jax arrays on the "device" specified for the simulation.
@@ -71,7 +71,9 @@ def __init__(
         self.num_envs = num_envs
         self.return_datatype = return_datatype
         self.device = jax.devices(kwargs["device"])[0]
-        self.max_episode_steps = jnp.array(max_episode_steps, dtype=jnp.int32, device=self.device)
+        self.time_horizon_in_seconds = jnp.array(
+            time_horizon_in_seconds, dtype=jnp.int32, device=self.device
+        )

         self.sim = Sim(**kwargs)

@@ -224,7 +226,7 @@ def terminated(self) -> Array:
     @property
     def truncated(self) -> Array:
         return self._truncated(
-            self.prev_done, self.sim.steps, self.max_episode_steps, self.n_substeps
+            self.prev_done, self.sim.time, self.time_horizon_in_seconds, self.n_substeps
         )

     def _reward() -> None:
@@ -235,17 +237,18 @@ def _reward() -> None:
     def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
         contact = jnp.any(contacts, axis=1)
         z_coords = states.pos[..., 2]
-        # Sanity check if we are below the ground. Should not be triggered due to collision checking
-        below_ground = jnp.any(z_coords < -0.1, axis=1)
-        terminated = jnp.logical_or(below_ground, contact)  # no termination condition
+        below_ground = jnp.any(
+            z_coords < -0.1, axis=1
+        )  # Sanity check if we are below the ground. Should not be triggered due to collision checking
+        terminated = jnp.logical_or(below_ground, contact)
         return jnp.where(dones, False, terminated)

     @staticmethod
     @jax.jit
     def _truncated(
-        dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
+        dones: Array, time: Array, time_horizon_in_seconds: Array, n_substeps: Array
     ) -> Array:
-        truncated = steps / n_substeps >= max_episode_steps
+        truncated = time >= time_horizon_in_seconds
         return jnp.where(dones, False, truncated)

     def render(self):
@@ -352,3 +355,40 @@ def _get_obs(self) -> dict[str, Array]:
         obs = super()._get_obs()
         obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
         return obs
+
+
+class CrazyflowEnvLanding(CrazyflowBaseEnv):
+    """JAX Gymnasium environment for Crazyflie simulation."""
+
+    def __init__(self, **kwargs: dict):
+        assert kwargs["n_drones"] == 1, "Currently only supported for one drone"
+
+        super().__init__(**kwargs)
+        self._obs_size += 3  # difference to goal position
+        self.single_observation_space = spaces.Box(
+            -jnp.inf, jnp.inf, shape=(self._obs_size,), dtype=jnp.float32
+        )
+        self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)
+
+        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
+        self.goal = self.goal.at[..., 2].set(0.1)  # 10cm above ground
+
+    @property
+    def reward(self) -> Array:
+        return self._reward(self.terminated, self.sim.states, self.goal)
+
+    @staticmethod
+    @jax.jit
+    def _reward(terminated: Array, states: SimState, goal: Array) -> Array:
+        norm_distance = jnp.linalg.norm(states.pos - goal, axis=2)
+        speed = jnp.linalg.norm(states.vel, axis=2)
+        reward = jnp.exp(-2.0 * norm_distance) * jnp.exp(-2.0 * speed)
+        return jnp.where(terminated, -1.0, reward)
+
+    def reset(self, mask: Array) -> None:
+        super().reset(mask)
+
+    def _get_obs(self) -> dict[str, Array]:
+        obs = super()._get_obs()
+        obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
+        return obs
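The landing reward is a product of two exponential kernels, so it only approaches 1.0 when the drone is simultaneously close to the goal and nearly at rest; a standalone sketch of that shaping, assuming the (n_worlds, n_drones, 3) state layout implied by axis=2 (the explicit goal broadcast is my addition, not code from the commit):

import jax.numpy as jnp

# Toy check of the landing reward shaping. Shapes assume
# (n_worlds, n_drones, 3); terminated worlds are overwritten with -1.0.
pos = jnp.array([[[0.0, 0.0, 0.1]], [[1.0, 0.0, 0.5]]])  # 2 worlds, 1 drone
vel = jnp.zeros_like(pos)
goal = jnp.array([[0.0, 0.0, 0.1], [0.0, 0.0, 0.1]])     # (n_worlds, 3)
terminated = jnp.array([[False], [False]])

norm_distance = jnp.linalg.norm(pos - goal[:, None, :], axis=2)
speed = jnp.linalg.norm(vel, axis=2)
reward = jnp.exp(-2.0 * norm_distance) * jnp.exp(-2.0 * speed)
print(jnp.where(terminated, -1.0, reward))
# World 0 hovers at the goal at rest -> reward 1.0;
# world 1 is ~1.08 m away -> reward ~0.12.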

examples/gymnasium_env.py

+2 −2

@@ -18,8 +18,8 @@
 SEED = 42

 envs = gymnasium.make_vec(
-    "CrazyflowEnvReachGoal-v0",
-    max_episode_steps=1000,
+    "DroneLanding-v0",
+    time_horizon_in_seconds=2,
     return_datatype="numpy",
     num_envs=sim_config.n_worlds,
     jax_random_key=SEED,
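With the example now pointing at the landing env, a rollout follows the standard Gymnasium vector API; a minimal sketch (the loop body is generic VectorEnv usage, not code from this commit):

# Hypothetical rollout against the `envs` constructed above.
obs, info = envs.reset(seed=SEED)
for _ in range(500):
    actions = envs.action_space.sample()  # random policy as a placeholder
    obs, reward, terminated, truncated, info = envs.step(actions)
envs.close()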
