30
30
randomize_gate_rpy_fn ,
31
31
randomize_obstacle_pos_fn ,
32
32
)
33
- from lsy_drone_racing .utils import check_gate_pass
34
33
35
34
if TYPE_CHECKING :
36
35
from crazyflow .sim .structs import SimData
@@ -227,7 +226,7 @@ def step(
227
226
self .sim .step (self .sim .freq // self .freq )
228
227
# TODO: Clean up the accelerated functions
229
228
self .disabled_drones = np .array (
230
- self .update_active_drones_acc (
229
+ self ._disabled_drones (
231
230
self .sim .data .states .pos [0 ],
232
231
self .sim .data .states .quat [0 ],
233
232
self .pos_bounds .low ,
@@ -242,7 +241,7 @@ def step(
242
241
)
243
242
self .sim .data = self .warp_disabled_drones (self .sim .data , self .disabled_drones )
244
243
# TODO: Clean up the accelerated functions
245
- passed = self .gate_passed_accelerated (
244
+ passed = self ._gate_passed (
246
245
self .target_gate ,
247
246
self .gates ["mocap_ids" ],
248
247
self .sim .data .mjx_data .mocap_pos [0 ],
@@ -263,21 +262,11 @@ def render(self):
263
262
def obs (self ) -> dict [str , NDArray [np .floating ]]:
264
263
"""Return the observation of the environment."""
265
264
# TODO: Accelerate this function
266
- obs = {
267
- "pos" : np .array (self .sim .data .states .pos [0 ], dtype = np .float32 ),
268
- "rpy" : R .from_quat (self .sim .data .states .quat [0 ]).as_euler ("xyz" ).astype (np .float32 ),
269
- "vel" : np .array (self .sim .data .states .vel [0 ], dtype = np .float32 ),
270
- "ang_vel" : np .array (self .sim .data .states .rpy_rates [0 ], dtype = np .float32 ),
271
- }
272
- obs ["target_gate" ] = self .target_gate
273
265
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
274
266
# use the actual pose, otherwise use the nominal pose.
275
- drone_pos = self .sim .data .states .pos [0 ]
276
- # Performance optimization: Get a continuous slice instead of using a list of indices which
277
- # copies the data. Assumes that the mocap ids are consecutive.
278
- gates_visited , gates_pos , gates_rpy = self .obs_acc_gates (
267
+ gates_visited , gates_pos , gates_rpy = self ._obs_gates (
279
268
self .gates_visited ,
280
- drone_pos ,
269
+ self . sim . data . states . pos [ 0 ] ,
281
270
self .sim .data .mjx_data .mocap_pos [0 ],
282
271
self .sim .data .mjx_data .mocap_quat [0 ],
283
272
self .gates ["mocap_ids" ],
@@ -286,62 +275,66 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
286
275
self .gates ["nominal_rpy" ],
287
276
)
288
277
self .gates_visited = np .asarray (gates_visited , dtype = bool )
289
- obs ["gates_pos" ] = np .asarray (gates_pos , dtype = np .float32 )
290
- obs ["gates_rpy" ] = np .asarray (gates_rpy , dtype = np .float32 )
291
- obs ["gates_visited" ] = self .gates_visited
292
-
293
- obstacles_visited , obstacles_pos = self .obs_acc_obstacles (
278
+ obstacles_visited , obstacles_pos = self ._obs_obstacles (
294
279
self .obstacles_visited ,
295
- drone_pos ,
280
+ self . sim . data . states . pos [ 0 ] ,
296
281
self .sim .data .mjx_data .mocap_pos [0 ],
297
282
self .obstacles ["mocap_ids" ],
298
283
self .sensor_range ,
299
284
self .obstacles ["nominal_pos" ],
300
285
)
301
286
self .obstacles_visited = np .asarray (obstacles_visited , dtype = bool )
302
- obs ["obstacles_pos" ] = np .asarray (obstacles_pos , dtype = np .float32 )
303
- obs ["obstacles_visited" ] = self .obstacles_visited
304
287
# TODO: Decide on observation disturbances
288
+ obs = {
289
+ "pos" : np .array (self .sim .data .states .pos [0 ], dtype = np .float32 ),
290
+ "rpy" : R .from_quat (self .sim .data .states .quat [0 ]).as_euler ("xyz" ).astype (np .float32 ),
291
+ "vel" : np .array (self .sim .data .states .vel [0 ], dtype = np .float32 ),
292
+ "ang_vel" : np .array (self .sim .data .states .rpy_rates [0 ], dtype = np .float32 ),
293
+ "target_gate" : self .target_gate ,
294
+ "gates_pos" : np .asarray (gates_pos , dtype = np .float32 ),
295
+ "gates_rpy" : np .asarray (gates_rpy , dtype = np .float32 ),
296
+ "gates_visited" : self .gates_visited ,
297
+ "obstacles_pos" : np .asarray (obstacles_pos , dtype = np .float32 ),
298
+ "obstacles_visited" : self .obstacles_visited ,
299
+ }
305
300
return obs
306
301
307
302
@staticmethod
308
303
@jax .jit
309
- def obs_acc_gates (
310
- gates_visited ,
311
- drone_pos ,
312
- mocap_pos ,
313
- mocap_quat ,
314
- mocap_ids ,
315
- sensor_range ,
316
- nominal_pos ,
317
- nominal_rpy ,
318
- ):
319
- # TODO: Clean up the accelerated functions
320
- gates_pos = mocap_pos [mocap_ids ]
321
- gates_quat = mocap_quat [mocap_ids ][..., [1 , 2 , 3 , 0 ]]
322
- gates_rpy = jax .scipy .spatial .transform .Rotation .from_quat (gates_quat ).as_euler ("xyz" )
323
- dpos = drone_pos [..., None , :2 ] - gates_pos [:, :2 ]
304
+ def _obs_gates (
305
+ gates_visited : NDArray ,
306
+ drone_pos : Array ,
307
+ mocap_pos : Array ,
308
+ mocap_quat : Array ,
309
+ mocap_ids : NDArray ,
310
+ sensor_range : float ,
311
+ nominal_pos : NDArray ,
312
+ nominal_rpy : NDArray ,
313
+ ) -> tuple [Array , Array , Array ]:
314
+ """Get the nominal or real gate positions and orientations depending on the sensor range."""
315
+ real_quat = mocap_quat [mocap_ids ][..., [1 , 2 , 3 , 0 ]]
316
+ real_rpy = jax .scipy .spatial .transform .Rotation .from_quat (real_quat ).as_euler ("xyz" )
317
+ dpos = drone_pos [..., None , :2 ] - mocap_pos [mocap_ids , :2 ]
324
318
in_range = jp .linalg .norm (dpos , axis = - 1 ) < sensor_range
325
319
gates_visited = jp .logical_or (gates_visited , in_range )
326
-
327
- mask = gates_visited [..., None ]
328
- gates_pos = jp .where (mask , gates_pos , nominal_pos )
329
- gates_rpy = jp .where (mask , gates_rpy , nominal_rpy )
320
+ gates_pos = jp .where (gates_visited [..., None ], mocap_pos [mocap_ids ], nominal_pos )
321
+ gates_rpy = jp .where (gates_visited [..., None ], real_rpy , nominal_rpy )
330
322
return gates_visited , gates_pos , gates_rpy
331
323
332
324
@staticmethod
333
325
@jax .jit
334
- def obs_acc_obstacles (
335
- obstacles_visited , drone_pos , mocap_pos , mocap_ids , sensor_range , nominal_pos
336
- ):
337
- # TODO: Clean up the accelerated functions
338
- obstacles_pos = mocap_pos [mocap_ids ]
339
- dpos = drone_pos [..., None , :2 ] - obstacles_pos [:, :2 ]
326
+ def _obs_obstacles (
327
+ visited : NDArray ,
328
+ drone_pos : Array ,
329
+ mocap_pos : Array ,
330
+ mocap_ids : NDArray ,
331
+ sensor_range : float ,
332
+ nominal_pos : NDArray ,
333
+ ) -> tuple [Array , Array ]:
334
+ dpos = drone_pos [..., None , :2 ] - mocap_pos [mocap_ids , :2 ]
340
335
in_range = jp .linalg .norm (dpos , axis = - 1 ) < sensor_range
341
- obstacles_visited = jp .logical_or (obstacles_visited , in_range )
342
- mask = obstacles_visited [..., None ]
343
- obstacles_pos = jp .where (mask , obstacles_pos , nominal_pos )
344
- return obstacles_visited , obstacles_pos
336
+ visited = jp .logical_or (visited , in_range )
337
+ return visited , jp .where (visited [..., None ], mocap_pos [mocap_ids ], nominal_pos )
345
338
346
339
def reward (self ) -> float :
347
340
"""Compute the reward for the current state.
@@ -368,33 +361,20 @@ def info(self) -> dict:
368
361
"""Return an info dictionary containing additional information about the environment."""
369
362
return {"collisions" : np .any (self .sim .contacts (), axis = - 1 ), "symbolic_model" : self .symbolic }
370
363
371
- def update_active_drones (self ):
372
- # TODO: Accelerate
373
- pos = self .sim .data .states .pos [0 , ...]
374
- rpy = R .from_quat (self .sim .data .states .quat [0 , ...]).as_euler ("xyz" )
375
- disabled = np .logical_or (self .disabled_drones , np .all (pos < self .pos_bounds .low , axis = - 1 ))
376
- disabled = np .logical_or (disabled , np .all (pos > self .pos_bounds .high , axis = - 1 ))
377
- disabled = np .logical_or (disabled , np .all (rpy < self .rpy_bounds .low , axis = - 1 ))
378
- disabled = np .logical_or (disabled , np .all (rpy > self .rpy_bounds .high , axis = - 1 ))
379
- disabled = np .logical_or (disabled , self .target_gate == - 1 )
380
- contacts = np .any (np .logical_and (self .sim .contacts (), self .contact_masks ), axis = - 1 )
381
- disabled = np .logical_or (disabled , contacts )
382
- self .disabled_drones = disabled
383
-
384
364
@staticmethod
385
365
@jax .jit
386
- def update_active_drones_acc (
387
- pos ,
388
- quat ,
389
- pos_low ,
390
- pos_high ,
391
- rpy_low ,
392
- rpy_high ,
393
- target_gate ,
394
- disabled_drones ,
395
- contacts ,
396
- contact_masks ,
397
- ):
366
+ def _disabled_drones (
367
+ pos : Array ,
368
+ quat : Array ,
369
+ pos_low : NDArray ,
370
+ pos_high : NDArray ,
371
+ rpy_low : NDArray ,
372
+ rpy_high : NDArray ,
373
+ target_gate : NDArray ,
374
+ disabled_drones : NDArray ,
375
+ contacts : Array ,
376
+ contact_masks : NDArray ,
377
+ ) -> Array :
398
378
rpy = jax .scipy .spatial .transform .Rotation .from_quat (quat ).as_euler ("xyz" )
399
379
disabled = jp .logical_or (disabled_drones , jp .all (pos < pos_low , axis = - 1 ))
400
380
disabled = jp .logical_or (disabled , jp .all (pos > pos_high , axis = - 1 ))
@@ -481,30 +461,9 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
481
461
mocap_ids = [int (mj_model .body (f"obstacle:{ i } " ).mocapid ) for i in range (n_obstacles )]
482
462
obstacles ["mocap_ids" ] = np .array (mocap_ids , dtype = np .int32 )
483
463
484
- def gate_passed (self ) -> bool :
485
- """Check if the drone has passed a gate.
486
-
487
- Returns:
488
- True if the drone has passed a gate, else False.
489
- """
490
- passed = np .zeros (self .sim .n_drones , dtype = bool )
491
- if self .n_gates <= 0 :
492
- return passed
493
- gate_ids = self .target_gate % self .n_gates
494
- gate_mj_id = self .gates ["mocap_ids" ][gate_ids ]
495
- gate_pos = self .sim .data .mjx_data .mocap_pos [0 , gate_mj_id ].squeeze ()
496
- gate_rot = R .from_quat (self .sim .data .mjx_data .mocap_quat [0 , gate_mj_id ], scalar_first = True )
497
- drone_pos = self .sim .data .states .pos [0 ]
498
- gate_size = (0.45 , 0.45 )
499
- for i in range (self .sim .n_drones ):
500
- passed [i ] = check_gate_pass (
501
- gate_pos [i ], gate_rot [i ], gate_size , drone_pos [i ], self ._last_drone_pos [i ]
502
- )
503
- return passed
504
-
505
464
@staticmethod
506
465
@jax .jit
507
- def gate_passed_accelerated (
466
+ def _gate_passed (
508
467
target_gate : NDArray ,
509
468
mocap_ids : NDArray ,
510
469
mocap_pos : Array ,
@@ -518,18 +477,16 @@ def gate_passed_accelerated(
518
477
Returns:
519
478
True if the drone has passed a gate, else False.
520
479
"""
521
- # TODO: Test, refactor, optimize. Cover cases with no gates.
522
- gate_ids = target_gate % n_gates
523
- gate_mj_id = mocap_ids [gate_ids ]
524
- gate_pos = mocap_pos [gate_mj_id ]
525
- gate_rot = jax .scipy .spatial .transform .Rotation .from_quat (
526
- mocap_quat [gate_mj_id ][..., [1 , 2 , 3 , 0 ]]
527
- )
480
+ # TODO: Test. Cover cases with no gates.
481
+ ids = mocap_ids [target_gate % n_gates ]
482
+ gate_pos = mocap_pos [ids ]
483
+ gate_quat = mocap_quat [ids ][..., [1 , 2 , 3 , 0 ]]
484
+ gate_rot = jax .scipy .spatial .transform .Rotation .from_quat (gate_quat )
528
485
gate_size = (0.45 , 0.45 )
529
486
last_pos_local = gate_rot .apply (last_drone_pos - gate_pos , inverse = True )
530
487
pos_local = gate_rot .apply (drone_pos - gate_pos , inverse = True )
531
- # Check the plane intersection. If passed, calculate the point of the intersection and check if
532
- # it is within the gate box.
488
+ # Check if the line between the last position and the current position intersects the plane.
489
+ # If so, calculate the point of the intersection and check if it is within the gate box.
533
490
passed_plane = (last_pos_local [..., 1 ] < 0 ) & (pos_local [..., 1 ] > 0 )
534
491
alpha = - last_pos_local [..., 1 ] / (pos_local [..., 1 ] - last_pos_local [..., 1 ])
535
492
x_intersect = alpha * (pos_local [..., 0 ]) + (1 - alpha ) * last_pos_local [..., 0 ]
@@ -540,6 +497,7 @@ def gate_passed_accelerated(
540
497
@staticmethod
541
498
@jax .jit
542
499
def warp_disabled_drones (data : SimData , mask : NDArray ) -> SimData :
500
+ """Warp the disabled drones below the ground."""
543
501
mask = mask .reshape ((1 , - 1 , 1 ))
544
502
pos = jax .numpy .where (mask , - 1 , data .states .pos )
545
503
return data .replace (states = data .states .replace (pos = pos ))
0 commit comments