Skip to content

Commit 1ec46d3

Browse files
committed
Fix render example
1 parent 3918827 commit 1ec46d3

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

examples/render.py

+37-24
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,51 @@
11
from collections import deque
22

3+
import einops
34
import mujoco
45
import numpy as np
56
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
7+
from numpy.typing import NDArray
68
from scipy.spatial.transform import Rotation as R
79

810
from crazyflow.sim import Physics, Sim
911

1012

11-
def render_traces(viewer: MujocoRenderer, pos: deque, quat: deque):
13+
def main():
14+
"""Spawn 25 drones in one world and render each with a trace behind it."""
15+
n_worlds, n_drones = 1, 25
16+
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.sys_id, device="cpu")
17+
fps = 60
18+
cmd = np.array([[[0.3, 0, 0, 0] for _ in range(sim.n_drones)]])
19+
20+
pos = deque(maxlen=16)
21+
rot = deque(maxlen=15)
22+
23+
for i in range(int(5 * sim.control_freq)):
24+
sim.attitude_control(cmd)
25+
sim.step(sim.freq // sim.control_freq)
26+
if i % 20 == 0:
27+
pos.append(sim.data.states.pos[0, :])
28+
if len(pos) > 1:
29+
rot.append(rotation_matrix_from_points(pos[-2], pos[-1]))
30+
if ((i * fps) % sim.control_freq) < fps:
31+
render_traces(sim.viewer, pos, rot)
32+
sim.render()
33+
sim.close()
34+
35+
36+
def render_traces(viewer: MujocoRenderer, pos: deque[NDArray], rot: deque[R]):
1237
"""Render traces of the drone trajectories."""
1338
if len(pos) < 2 or viewer is None:
1439
return
1540

16-
n_trace, n_drones = len(pos) - 1, len(pos[0])
17-
pos, quat = np.array(pos), np.array(quat)
41+
n_trace, n_drones = len(rot), len(pos[0])
42+
pos = np.array(pos)
1843
sizes = np.zeros((n_trace, n_drones, 3))
1944
rgbas = np.zeros((n_trace, n_drones, 4))
2045
sizes[..., 2] = np.linalg.norm(pos[1:] - pos[:-1], axis=-1)
21-
mats = R.from_quat(quat[1:].reshape(-1, 4)).as_matrix().flatten().reshape(n_trace, -1, 9)
46+
mats = np.zeros((n_trace, n_drones, 9))
47+
for i in range(n_trace):
48+
mats[i, :] = einops.rearrange(rot[i].as_matrix(), "d n m -> d (n m)")
2249
rgbas = np.zeros((n_trace, n_drones, 4))
2350
np.random.seed(0) # Ensure consistent colors
2451
rgbas[..., :3] = np.random.uniform(0, 1, (1, n_drones, 3))
@@ -35,26 +62,12 @@ def render_traces(viewer: MujocoRenderer, pos: deque, quat: deque):
3562
)
3663

3764

38-
def main():
39-
"""Spawn 25 drones in one world and render each with a trace behind it."""
40-
n_worlds, n_drones = 1, 25
41-
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.sys_id, device="cpu")
42-
fps = 60
43-
cmd = np.array([[[0.3, 0, 0, 0] for _ in range(sim.n_drones)]])
44-
45-
pos = deque(maxlen=15)
46-
quat = deque(maxlen=15)
47-
48-
for i in range(int(5 * sim.control_freq)):
49-
sim.attitude_control(cmd)
50-
sim.step(sim.freq // sim.control_freq)
51-
if i % 20 == 0:
52-
pos.append(sim.data.states.pos[0, :])
53-
quat.append(sim.data.states.quat[0, :])
54-
if ((i * fps) % sim.control_freq) < fps:
55-
render_traces(sim.viewer, pos, quat)
56-
sim.render()
57-
sim.close()
65+
def rotation_matrix_from_points(p1: NDArray, p2: NDArray) -> R:
66+
z_axis = (v := p2 - p1) / np.linalg.norm(v, axis=-1, keepdims=True)
67+
random_vector = np.random.rand(*z_axis.shape)
68+
x_axis = (v := np.cross(random_vector, z_axis)) / np.linalg.norm(v, axis=-1, keepdims=True)
69+
y_axis = np.cross(z_axis, x_axis)
70+
return R.from_matrix(np.stack((x_axis, y_axis, z_axis), axis=-1))
5871

5972

6073
if __name__ == "__main__":

0 commit comments

Comments
 (0)