1
- import jax . numpy as jnp
1
+ import numpy as np
2
2
import pytest
3
+ from scipy .spatial .transform import Rotation as R
3
4
4
5
from crazyflow .control .control import Control , state2attitude
5
6
from crazyflow .sim import Physics , Sim
@@ -11,43 +12,43 @@ def test_state_interface(physics: Physics):
11
12
sim = Sim (physics = physics , control = Control .state )
12
13
13
14
# Simple P controller for attitude to reach target height
14
- cmd = jnp .zeros ((1 , 1 , 13 ), dtype = jnp .float32 )
15
- cmd = cmd . at [0 , 0 , 2 ]. set ( 1.0 ) # Set z position target to 1.0
15
+ cmd = np .zeros ((1 , 1 , 13 ), dtype = np .float32 )
16
+ cmd [0 , 0 , 2 ] = 1.0 # Set z position target to 1.0
16
17
17
18
for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
18
19
sim .state_control (cmd )
19
20
sim .step (sim .freq // sim .control_freq )
20
- if jnp .linalg .norm (sim .data .states .pos [0 , 0 ] - jnp .array ([0.0 , 0.0 , 1.0 ])) < 0.1 :
21
+ if np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ])) < 0.1 :
21
22
break
22
23
23
24
# Check if drone reached target position
24
- distance = jnp .linalg .norm (sim .data .states .pos [0 , 0 ] - jnp .array ([0.0 , 0.0 , 1.0 ]))
25
+ distance = np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ]))
25
26
assert distance < 0.1 , f"Failed to reach target height with { physics } physics"
26
27
27
28
28
29
@pytest .mark .integration
29
30
@pytest .mark .parametrize ("physics" , Physics )
30
31
def test_attitude_interface (physics : Physics ):
31
32
sim = Sim (physics = physics , control = Control .attitude )
32
- target_pos = jnp .array ([0.0 , 0.0 , 1.0 ])
33
+ target_pos = np .array ([0.0 , 0.0 , 1.0 ])
33
34
34
- i_error = jnp .zeros ((1 , 1 , 3 ))
35
+ i_error = np .zeros ((1 , 1 , 3 ))
35
36
36
37
for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
37
38
pos , vel , quat = sim .data .states .pos , sim .data .states .vel , sim .data .states .quat
38
- des_pos = jnp .array ([[[0 , 0 , 1.0 ]]])
39
+ des_pos = np .array ([[[0 , 0 , 1.0 ]]])
39
40
dt = 1 / sim .data .controls .attitude_freq
40
41
cmd , i_error = state2attitude (
41
- pos , vel , quat , des_pos , jnp .zeros ((1 , 1 , 3 )), jnp .zeros ((1 , 1 , 1 )), i_error , dt
42
+ pos , vel , quat , des_pos , np .zeros ((1 , 1 , 3 )), np .zeros ((1 , 1 , 1 )), i_error , dt
42
43
)
43
44
sim .attitude_control (cmd )
44
45
sim .step (sim .freq // sim .control_freq )
45
- if jnp .linalg .norm (sim .data .states .pos [0 , 0 ] - target_pos ) < 0.1 :
46
+ if np .linalg .norm (sim .data .states .pos [0 , 0 ] - target_pos ) < 0.1 :
46
47
break
47
48
48
49
# Check if drone maintained hover position
49
50
dpos = sim .data .states .pos [0 , 0 ] - target_pos
50
- distance = jnp .linalg .norm (dpos )
51
+ distance = np .linalg .norm (dpos )
51
52
assert distance < 0.1 , f"Failed to maintain hover with { physics } ({ dpos } )"
52
53
53
54
@@ -56,17 +57,40 @@ def test_attitude_interface(physics: Physics):
56
57
def test_swarm_control (physics : Physics ):
57
58
n_worlds , n_drones = 2 , 3
58
59
sim = Sim (n_worlds = n_worlds , n_drones = n_drones , physics = physics , control = Control .state )
59
- target_pos = sim .data .states .pos + jnp .array ([0.2 , 0.2 , 0.2 ])
60
+ target_pos = sim .data .states .pos + np .array ([0.2 , 0.2 , 0.2 ])
60
61
61
- cmd = jnp .zeros ((n_worlds , n_drones , 13 ))
62
- cmd = cmd . at [..., :3 ]. set ( target_pos )
62
+ cmd = np .zeros ((n_worlds , n_drones , 13 ))
63
+ cmd [..., :3 ] = target_pos
63
64
64
65
for _ in range (int (3 * sim .control_freq )): # Run simulation for 2 seconds
65
66
sim .state_control (cmd )
66
67
sim .step (sim .freq // sim .control_freq )
67
- if jnp .linalg .norm (sim .data .states .pos [0 , 0 ] - target_pos ) < 0.1 :
68
+ if np .linalg .norm (sim .data .states .pos [0 , 0 ] - target_pos ) < 0.1 :
68
69
break
69
70
70
71
# Check if drone maintained hover position
71
- max_dist = jnp .max (jnp .linalg .norm (sim .data .states .pos - target_pos , axis = - 1 ))
72
+ max_dist = np .max (np .linalg .norm (sim .data .states .pos - target_pos , axis = - 1 ))
72
73
assert max_dist < 0.05 , f"Failed to reach target, max dist: { max_dist } "
74
+
75
+
76
+ @pytest .mark .integration
77
+ @pytest .mark .parametrize ("physics" , Physics )
78
+ def test_yaw_rotation (physics : Physics ):
79
+ if physics == Physics .sys_id : # TODO: Remove once yaw is supported for sys_id
80
+ pytest .skip ("Yaw != 0 currently not supported for sys_id" )
81
+
82
+ sim = Sim (physics = physics , control = Control .state )
83
+ sim .reset ()
84
+
85
+ cmd = np .zeros ((sim .n_worlds , sim .n_drones , 13 ))
86
+ cmd [..., :3 ] = 0.2
87
+ cmd [..., 9 ] = np .pi / 2 # Test if the drone can rotate in yaw
88
+
89
+ sim .state_control (cmd )
90
+ sim .step (sim .freq * 2 )
91
+ pos = sim .data .states .pos [0 , 0 ]
92
+ rot = R .from_quat (sim .data .states .quat [0 , 0 ])
93
+ distance = np .linalg .norm (pos - np .array ([0.2 , 0.2 , 0.2 ]))
94
+ assert distance < 0.1 , f"Failed to reach target, distance: { distance } "
95
+ angle = rot .as_euler ("xyz" )[2 ]
96
+ assert np .abs (angle - np .pi / 2 ) < 0.1 , f"Failed to rotate in yaw, angle: { angle } "
0 commit comments