Skip to content

Commit 1af6964

Browse files
committed
[wip,broken] Temporary workaround for lsy_models compatibility
1 parent 82699ce commit 1af6964

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

benchmark/op_count.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ def main():
88

99
compiled_reset = sim._reset.lower(sim.data, sim.default_data, None).compile()
1010
compiled_step = sim._step.lower(sim.data, 1).compile()
11-
op_count_reset = compiled_reset.cost_analysis()[0]["flops"]
12-
op_count_step = compiled_step.cost_analysis()[0]["flops"]
11+
op_count_reset = compiled_reset.cost_analysis()["flops"]
12+
op_count_step = compiled_step.cost_analysis()["flops"]
1313
print(f"Op counts:\n Reset: {op_count_reset}\n Step: {op_count_step}")
1414

1515

crazyflow/constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88
ARM_LEN: float = 0.0325 * np.sqrt(2)
99
MIX_MATRIX: NDArray = np.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
1010
SIGN_MIX_MATRIX: NDArray = np.sign(MIX_MATRIX)
11-
MASS: float = 0.027
11+
MASS: float = 0.03253
1212
J: NDArray = np.array([[2.3951e-5, 0, 0], [0, 2.3951e-5, 0], [0, 0, 3.2347e-5]])
1313
J_INV: NDArray = np.linalg.inv(J)

crazyflow/sim/structs.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import jax.numpy as jnp
77
from flax.struct import dataclass, field
88
from jax import Array, Device
9+
from lsy_models.utils.constants import Constants
910

1011
if TYPE_CHECKING:
1112
from mujoco.mjx import Data, Model
13+
from numpy.typing import NDArray
1214

1315

1416
@dataclass
@@ -161,6 +163,16 @@ class SimParams:
161163
J_INV: Array # (N, M, 3, 3)
162164
"""Inverse of the inertia matrix of the drone."""
163165

166+
# TODO: Remove duplicate definition of constants. Move into constants from lsy_models
167+
THRUST_TAU: float = field(pytree_node=False)
168+
SIGN_MATRIX: NDArray = field(pytree_node=False)
169+
L: float = field(pytree_node=False)
170+
KF: float = field(pytree_node=False)
171+
KM: float = field(pytree_node=False)
172+
GRAVITY_VEC: NDArray = field(pytree_node=False)
173+
MASS: float = field(pytree_node=False)
174+
J_inv: NDArray = field(pytree_node=False)
175+
164176
@staticmethod
165177
def create(
166178
n_worlds: int, n_drones: int, mass: float, J: Array, J_INV: Array, device: Device
@@ -170,7 +182,21 @@ def create(
170182
j, j_inv = jnp.array(J, device=device), jnp.array(J_INV, device=device)
171183
J = jnp.tile(j[None, None, :, :], (n_worlds, n_drones, 1, 1))
172184
J_INV = jnp.tile(j_inv[None, None, :, :], (n_worlds, n_drones, 1, 1))
173-
return SimParams(mass=mass, J=J, J_INV=J_INV)
185+
constants = Constants.from_config("cf2x_L250")
186+
187+
return SimParams(
188+
mass=mass,
189+
J=constants.J,
190+
J_INV=J_INV,
191+
THRUST_TAU=constants.THRUST_TAU,
192+
SIGN_MATRIX=constants.SIGN_MATRIX,
193+
L=constants.L,
194+
KF=constants.KF,
195+
KM=constants.KM,
196+
GRAVITY_VEC=constants.GRAVITY_VEC,
197+
MASS=constants.MASS,
198+
J_inv=constants.J_inv,
199+
)
174200

175201

176202
@dataclass

0 commit comments

Comments
 (0)