6
6
import jax .numpy as jnp
7
7
from flax .struct import dataclass , field
8
8
from jax import Array , Device
9
+ from lsy_models .utils .constants import Constants
9
10
10
11
if TYPE_CHECKING :
11
12
from mujoco .mjx import Data , Model
13
+ from numpy .typing import NDArray
12
14
13
15
14
16
@dataclass
@@ -161,6 +163,16 @@ class SimParams:
161
163
J_INV : Array # (N, M, 3, 3)
162
164
"""Inverse of the inertia matrix of the drone."""
163
165
166
+ # TODO: Remove duplicate definition of constants. Move into constants from lsy_models
167
+ THRUST_TAU : float = field (pytree_node = False )
168
+ SIGN_MATRIX : NDArray = field (pytree_node = False )
169
+ L : float = field (pytree_node = False )
170
+ KF : float = field (pytree_node = False )
171
+ KM : float = field (pytree_node = False )
172
+ GRAVITY_VEC : NDArray = field (pytree_node = False )
173
+ MASS : float = field (pytree_node = False )
174
+ J_inv : NDArray = field (pytree_node = False )
175
+
164
176
@staticmethod
165
177
def create (
166
178
n_worlds : int , n_drones : int , mass : float , J : Array , J_INV : Array , device : Device
@@ -170,7 +182,21 @@ def create(
170
182
j , j_inv = jnp .array (J , device = device ), jnp .array (J_INV , device = device )
171
183
J = jnp .tile (j [None , None , :, :], (n_worlds , n_drones , 1 , 1 ))
172
184
J_INV = jnp .tile (j_inv [None , None , :, :], (n_worlds , n_drones , 1 , 1 ))
173
- return SimParams (mass = mass , J = J , J_INV = J_INV )
185
+ constants = Constants .from_config ("cf2x_L250" )
186
+
187
+ return SimParams (
188
+ mass = mass ,
189
+ J = constants .J ,
190
+ J_INV = J_INV ,
191
+ THRUST_TAU = constants .THRUST_TAU ,
192
+ SIGN_MATRIX = constants .SIGN_MATRIX ,
193
+ L = constants .L ,
194
+ KF = constants .KF ,
195
+ KM = constants .KM ,
196
+ GRAVITY_VEC = constants .GRAVITY_VEC ,
197
+ MASS = constants .MASS ,
198
+ J_inv = constants .J_inv ,
199
+ )
174
200
175
201
176
202
@dataclass
0 commit comments