@@ -221,8 +221,6 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
221
221
"vel" : np .array (self .sim .data .states .vel [0 , 0 ], dtype = np .float32 ),
222
222
"ang_vel" : np .array (self .sim .data .states .rpy_rates [0 , 0 ], dtype = np .float32 ),
223
223
}
224
- obs ["ang_vel" ][:] = R .from_euler ("xyz" , obs ["rpy" ]).apply (obs ["ang_vel" ], inverse = True )
225
-
226
224
obs ["target_gate" ] = self .target_gate if self .target_gate < len (self .gates ) else - 1
227
225
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
228
226
# use the actual pose, otherwise use the nominal pose.
@@ -247,9 +245,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
247
245
obstacles_pos [self .obstacles_visited ] = self .obstacles ["pos" ][self .obstacles_visited ]
248
246
obs ["obstacles_pos" ] = obstacles_pos .astype (np .float32 )
249
247
obs ["obstacles_visited" ] = self .obstacles_visited
250
-
251
- if "observation" in self .disturbances :
252
- obs = self .disturbances ["observation" ].apply (obs )
248
+ # TODO: Observation disturbances?
253
249
return obs
254
250
255
251
def reward (self ) -> float :
@@ -370,10 +366,13 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
370
366
assert not hasattr (self .sim .data , "gate_pos" )
371
367
assert not hasattr (self .sim .data , "obstacle_pos" )
372
368
373
- gate_ids = [self .sim .mj_model .body (f"gate:{ i } " ).id for i in range (n_gates )]
374
- gates ["ids" ] = gate_ids
375
- obstacle_ids = [self .sim .mj_model .body (f"obstacle:{ i } " ).id for i in range (n_obstacles )]
376
- obstacles ["ids" ] = obstacle_ids
369
+ mj_model = self .sim .mj_model
370
+ gates ["ids" ] = [mj_model .body (f"gate:{ i } " ).id for i in range (n_gates )]
371
+ gates ["mocap_ids" ] = [int (mj_model .body (f"gate:{ i } " ).mocapid ) for i in range (n_gates )]
372
+ obstacles ["ids" ] = [mj_model .body (f"obstacle:{ i } " ).id for i in range (n_obstacles )]
373
+ obstacles ["mocap_ids" ] = [
374
+ int (mj_model .body (f"obstacle:{ i } " ).mocapid ) for i in range (n_obstacles )
375
+ ]
377
376
378
377
def gate_passed (self ) -> bool :
379
378
"""Check if the drone has passed a gate.
@@ -383,12 +382,12 @@ def gate_passed(self) -> bool:
383
382
"""
384
383
if self .n_gates <= 0 or self .target_gate >= self .n_gates or self .target_gate == - 1 :
385
384
return False
386
- gate_id = self .gates ["ids " ][self .target_gate ]
387
- gate_pos = self .sim .data .mjx_data .mocap_pos [0 , gate_id ]
388
- gate_quat = self .sim .data .mjx_data .mocap_quat [0 , gate_id ][..., [ 3 , 0 , 1 , 2 ]]
385
+ gate_mj_id = self .gates ["mocap_ids " ][self .target_gate ]
386
+ gate_pos = self .sim .data .mjx_data .mocap_pos [0 , gate_mj_id ]. squeeze ()
387
+ gate_rot = R . from_quat ( self .sim .data .mjx_data .mocap_quat [0 , gate_mj_id ], scalar_first = True )
389
388
drone_pos = self .sim .data .states .pos [0 , 0 ]
390
389
gate_size = (0.45 , 0.45 )
391
- return check_gate_pass (gate_pos , gate_quat , gate_size , drone_pos , self ._last_drone_pos )
390
+ return check_gate_pass (gate_pos , gate_rot , gate_size , drone_pos , self ._last_drone_pos )
392
391
393
392
def close (self ):
394
393
"""Close the environment by stopping the drone and landing back at the starting position."""
0 commit comments