19
19
from crazyflow import Sim
20
20
from crazyflow .sim .symbolic import symbolic_attitude
21
21
from gymnasium import spaces
22
+ from jax .scipy .spatial .transform import Rotation as JaxR
22
23
from scipy .spatial .transform import Rotation as R
23
24
24
25
from lsy_drone_racing .envs .randomize import (
@@ -302,7 +303,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
302
303
@staticmethod
303
304
@jax .jit
304
305
def _obs_gates (
305
- gates_visited : NDArray ,
306
+ visited : NDArray ,
306
307
drone_pos : Array ,
307
308
mocap_pos : Array ,
308
309
mocap_quat : Array ,
@@ -312,14 +313,12 @@ def _obs_gates(
312
313
nominal_rpy : NDArray ,
313
314
) -> tuple [Array , Array , Array ]:
314
315
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
315
- real_quat = mocap_quat [mocap_ids ][..., [1 , 2 , 3 , 0 ]]
316
- real_rpy = jax .scipy .spatial .transform .Rotation .from_quat (real_quat ).as_euler ("xyz" )
316
+ real_rpy = JaxR .from_quat (mocap_quat [mocap_ids ][..., [1 , 2 , 3 , 0 ]]).as_euler ("xyz" )
317
317
dpos = drone_pos [..., None , :2 ] - mocap_pos [mocap_ids , :2 ]
318
- in_range = jp .linalg .norm (dpos , axis = - 1 ) < sensor_range
319
- gates_visited = jp .logical_or (gates_visited , in_range )
320
- gates_pos = jp .where (gates_visited [..., None ], mocap_pos [mocap_ids ], nominal_pos )
321
- gates_rpy = jp .where (gates_visited [..., None ], real_rpy , nominal_rpy )
322
- return gates_visited , gates_pos , gates_rpy
318
+ visited = jp .logical_or (visited , jp .linalg .norm (dpos , axis = - 1 ) < sensor_range )
319
+ gates_pos = jp .where (visited [..., None ], mocap_pos [mocap_ids ], nominal_pos )
320
+ gates_rpy = jp .where (visited [..., None ], real_rpy , nominal_rpy )
321
+ return visited , gates_pos , gates_rpy
323
322
324
323
@staticmethod
325
324
@jax .jit
@@ -332,8 +331,7 @@ def _obs_obstacles(
332
331
nominal_pos : NDArray ,
333
332
) -> tuple [Array , Array ]:
334
333
dpos = drone_pos [..., None , :2 ] - mocap_pos [mocap_ids , :2 ]
335
- in_range = jp .linalg .norm (dpos , axis = - 1 ) < sensor_range
336
- visited = jp .logical_or (visited , in_range )
334
+ visited = jp .logical_or (visited , jp .linalg .norm (dpos , axis = - 1 ) < sensor_range )
337
335
return visited , jp .where (visited [..., None ], mocap_pos [mocap_ids ], nominal_pos )
338
336
339
337
def reward (self ) -> float :
@@ -375,7 +373,7 @@ def _disabled_drones(
375
373
contacts : Array ,
376
374
contact_masks : NDArray ,
377
375
) -> Array :
378
- rpy = jax . scipy . spatial . transform . Rotation .from_quat (quat ).as_euler ("xyz" )
376
+ rpy = JaxR .from_quat (quat ).as_euler ("xyz" )
379
377
disabled = jp .logical_or (disabled_drones , jp .all (pos < pos_low , axis = - 1 ))
380
378
disabled = jp .logical_or (disabled , jp .all (pos > pos_high , axis = - 1 ))
381
379
disabled = jp .logical_or (disabled , jp .all (rpy < rpy_low , axis = - 1 ))
@@ -480,8 +478,7 @@ def _gate_passed(
480
478
# TODO: Test. Cover cases with no gates.
481
479
ids = mocap_ids [target_gate % n_gates ]
482
480
gate_pos = mocap_pos [ids ]
483
- gate_quat = mocap_quat [ids ][..., [1 , 2 , 3 , 0 ]]
484
- gate_rot = jax .scipy .spatial .transform .Rotation .from_quat (gate_quat )
481
+ gate_rot = JaxR .from_quat (mocap_quat [ids ][..., [1 , 2 , 3 , 0 ]])
485
482
gate_size = (0.45 , 0.45 )
486
483
last_pos_local = gate_rot .apply (last_drone_pos - gate_pos , inverse = True )
487
484
pos_local = gate_rot .apply (drone_pos - gate_pos , inverse = True )
0 commit comments