@@ -357,12 +357,6 @@ def _obs(self) -> dict[str, Array]:
357
357
obs = super ()._obs ()
358
358
obs ["difference_to_target_vel" ] = [self .target_vel - self .sim .states .vel ]
359
359
return obs
360
class CrazyflowEnvLanding(CrazyflowBaseEnv):
    """JAX Gymnasium environment for Crazyflie simulation (landing task)."""

    # NOTE(review): the full file defines __init__(self, **kwargs: dict) here,
    # but its body is outside this view (elided by the diff hunk) and is not
    # reconstructed — confirm against the repository before relying on this.
@property
384
378
def reward (self ) -> Array :
385
- << << << < Updated upstream
386
- return self ._reward (self .terminated , self .sim .states , self .goal )
387
-
388
- @staticmethod
389
- @jax .jit
390
- def _reward (terminated : Array , states : SimState , goal : Array ) -> Array :
391
- norm_distance = jnp .linalg .norm (states .pos - goal , axis = 2 )
392
- speed = jnp .linalg .norm (states .vel , axis = 2 )
393
- reward = jnp .exp (- 2.0 * norm_distance ) * jnp .exp (- 2.0 * speed )
394
- return jnp .where (terminated , - 1.0 , reward )
395
- == == == =
396
379
return self ._reward (self .prev_done , self .terminated , self .sim .states , self .goal )
397
380
398
381
@staticmethod
@@ -404,16 +387,11 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
404
387
reward = jnp .where (terminated , - 1.0 , reward )
405
388
reward = jnp .where (prev_done , 0.0 , reward )
406
389
return reward
407
- >> >> >> > Stashed changes
408
390
409
391
def reset (self , mask : Array ) -> None :
410
392
super ().reset (mask )
411
393
412
394
def _get_obs (self ) -> dict [str , Array ]:
413
395
obs = super ()._get_obs ()
414
396
obs ["difference_to_goal" ] = [self .goal - self .sim .states .pos ]
415
- return obs
416
- << << << < Updated upstream
417
- == == == =
418
- >> >> >> > Stashed changes
419
- >> >> >> > Stashed changes
397
+ return obs
0 commit comments