Skip to content

Commit 6cb3702

Browse files
committed
Fix incorrect torque conversion in collective_torque2rpy_rates_deriv. Add tests.
1 parent 3b62ca2 commit 6cb3702

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

crazyflow/sim/physics.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def collective_force2acceleration(force: Array, mass: Array) -> Array:
8787
@partial(vectorize, signature="(3),(4),(3,3)->(3)")
8888
def collective_torque2rpy_rates_deriv(torque: Array, quat: Array, J_INV: Array) -> Array:
8989
"""Convert torques to rpy_rates_deriv."""
90-
return R.from_quat(quat).apply(J_INV @ torque)
90+
rot = R.from_quat(quat)
91+
return rot.apply(J_INV @ rot.apply(torque, inverse=True))
9192

9293

9394
@partial(vectorize, signature="(4),(4),(3),(3,3)->(3),(3)")

tests/integration/test_interfaces.py

+40-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import jax.numpy as jnp
1+
import numpy as np
22
import pytest
3+
from scipy.spatial.transform import Rotation as R
34

45
from crazyflow.control.control import Control, state2attitude
56
from crazyflow.sim import Physics, Sim
@@ -11,43 +12,43 @@ def test_state_interface(physics: Physics):
1112
sim = Sim(physics=physics, control=Control.state)
1213

1314
# Simple P controller for attitude to reach target height
14-
cmd = jnp.zeros((1, 1, 13), dtype=jnp.float32)
15-
cmd = cmd.at[0, 0, 2].set(1.0) # Set z position target to 1.0
15+
cmd = np.zeros((1, 1, 13), dtype=np.float32)
16+
cmd[0, 0, 2] = 1.0 # Set z position target to 1.0
1617

1718
for _ in range(int(2 * sim.control_freq)): # Run simulation for 2 seconds
1819
sim.state_control(cmd)
1920
sim.step(sim.freq // sim.control_freq)
20-
if jnp.linalg.norm(sim.data.states.pos[0, 0] - jnp.array([0.0, 0.0, 1.0])) < 0.1:
21+
if np.linalg.norm(sim.data.states.pos[0, 0] - np.array([0.0, 0.0, 1.0])) < 0.1:
2122
break
2223

2324
# Check if drone reached target position
24-
distance = jnp.linalg.norm(sim.data.states.pos[0, 0] - jnp.array([0.0, 0.0, 1.0]))
25+
distance = np.linalg.norm(sim.data.states.pos[0, 0] - np.array([0.0, 0.0, 1.0]))
2526
assert distance < 0.1, f"Failed to reach target height with {physics} physics"
2627

2728

2829
@pytest.mark.integration
2930
@pytest.mark.parametrize("physics", Physics)
3031
def test_attitude_interface(physics: Physics):
3132
sim = Sim(physics=physics, control=Control.attitude)
32-
target_pos = jnp.array([0.0, 0.0, 1.0])
33+
target_pos = np.array([0.0, 0.0, 1.0])
3334

34-
i_error = jnp.zeros((1, 1, 3))
35+
i_error = np.zeros((1, 1, 3))
3536

3637
for _ in range(int(2 * sim.control_freq)): # Run simulation for 2 seconds
3738
pos, vel, quat = sim.data.states.pos, sim.data.states.vel, sim.data.states.quat
38-
des_pos = jnp.array([[[0, 0, 1.0]]])
39+
des_pos = np.array([[[0, 0, 1.0]]])
3940
dt = 1 / sim.data.controls.attitude_freq
4041
cmd, i_error = state2attitude(
41-
pos, vel, quat, des_pos, jnp.zeros((1, 1, 3)), jnp.zeros((1, 1, 1)), i_error, dt
42+
pos, vel, quat, des_pos, np.zeros((1, 1, 3)), np.zeros((1, 1, 1)), i_error, dt
4243
)
4344
sim.attitude_control(cmd)
4445
sim.step(sim.freq // sim.control_freq)
45-
if jnp.linalg.norm(sim.data.states.pos[0, 0] - target_pos) < 0.1:
46+
if np.linalg.norm(sim.data.states.pos[0, 0] - target_pos) < 0.1:
4647
break
4748

4849
# Check if drone maintained hover position
4950
dpos = sim.data.states.pos[0, 0] - target_pos
50-
distance = jnp.linalg.norm(dpos)
51+
distance = np.linalg.norm(dpos)
5152
assert distance < 0.1, f"Failed to maintain hover with {physics} ({dpos})"
5253

5354

@@ -56,17 +57,40 @@ def test_attitude_interface(physics: Physics):
5657
def test_swarm_control(physics: Physics):
5758
n_worlds, n_drones = 2, 3
5859
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, control=Control.state)
59-
target_pos = sim.data.states.pos + jnp.array([0.2, 0.2, 0.2])
60+
target_pos = sim.data.states.pos + np.array([0.2, 0.2, 0.2])
6061

61-
cmd = jnp.zeros((n_worlds, n_drones, 13))
62-
cmd = cmd.at[..., :3].set(target_pos)
62+
cmd = np.zeros((n_worlds, n_drones, 13))
63+
cmd[..., :3] = target_pos
6364

6465
for _ in range(int(3 * sim.control_freq)): # Run simulation for 2 seconds
6566
sim.state_control(cmd)
6667
sim.step(sim.freq // sim.control_freq)
67-
if jnp.linalg.norm(sim.data.states.pos[0, 0] - target_pos) < 0.1:
68+
if np.linalg.norm(sim.data.states.pos[0, 0] - target_pos) < 0.1:
6869
break
6970

7071
# Check if drone maintained hover position
71-
max_dist = jnp.max(jnp.linalg.norm(sim.data.states.pos - target_pos, axis=-1))
72+
max_dist = np.max(np.linalg.norm(sim.data.states.pos - target_pos, axis=-1))
7273
assert max_dist < 0.05, f"Failed to reach target, max dist: {max_dist}"
74+
75+
76+
@pytest.mark.integration
77+
@pytest.mark.parametrize("physics", Physics)
78+
def test_yaw_rotation(physics: Physics):
79+
if physics == Physics.sys_id: # TODO: Remove once yaw is supported for sys_id
80+
pytest.skip("Yaw != 0 currently not supported for sys_id")
81+
82+
sim = Sim(physics=physics, control=Control.state)
83+
sim.reset()
84+
85+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
86+
cmd[..., :3] = 0.2
87+
cmd[..., 9] = np.pi / 2 # Test if the drone can rotate in yaw
88+
89+
sim.state_control(cmd)
90+
sim.step(sim.freq * 2)
91+
pos = sim.data.states.pos[0, 0]
92+
rot = R.from_quat(sim.data.states.quat[0, 0])
93+
distance = np.linalg.norm(pos - np.array([0.2, 0.2, 0.2]))
94+
assert distance < 0.1, f"Failed to reach target, distance: {distance}"
95+
angle = rot.as_euler("xyz")[2]
96+
assert np.abs(angle - np.pi / 2) < 0.1, f"Failed to rotate in yaw, angle: {angle}"

0 commit comments

Comments
 (0)