Skip to content

Commit 50fb8f2

Browse files
committed
Fix docs and formatting
1 parent 024b37b commit 50fb8f2

File tree

3 files changed

+7
-21
lines changed

3 files changed

+7
-21
lines changed

crazyflow/gymnasium_envs/crazyflow.py

+3-19
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def __init__(
6666
Args:
6767
num_envs: The number of environments to run in parallel.
6868
time_horizon_in_seconds: The time horizon after which episodes are truncated.
69-
**kwargs: Takes arguments that are passed to the Crazyfly simulation.
69+
physics: The crazyflow physics simulation model.
70+
freq: The frequency at which the environment is run.
71+
device: The device of the environment and the simulation.
7072
"""
7173
self.num_envs = num_envs
7274
self.device = jax.devices(device)[0]
@@ -530,21 +532,3 @@ def actions(self, actions: Array) -> Array:
530532
# Ensure actions are within the valid range of the simulation action space
531533
rescaled_actions = np.clip(rescaled_actions, self.action_sim_low, self.action_sim_high)
532534
return rescaled_actions
533-
534-
535-
def render_trajectory(viewer: MujocoRenderer | None, pos: Array) -> None:
536-
"""Render trajectory."""
537-
if viewer is None:
538-
return
539-
540-
pos = np.array(pos[0]).transpose(1, 0, 2)
541-
n_trace, n_drones = len(pos) - 1, len(pos[0])
542-
543-
for i in range(n_trace):
544-
for j in range(n_drones):
545-
viewer.viewer.add_marker(
546-
type=mujoco.mjtGeom.mjGEOM_SPHERE,
547-
size=np.array([0.02, 0.02, 0.02]),
548-
pos=pos[i][j],
549-
rgba=np.array([1, 0, 0, 0.8]),
550-
)

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[build-system]
22
requires = ["setuptools>=61.0.0", "wheel", "numpy"]
33
build-backend = "setuptools.build_meta"
4-
requires-python = "3.11" # tested in python 3.11
4+
requires-python = "3.11" # tested in python 3.11
55

66
[project]
77
name = "crazyflow"
@@ -81,6 +81,7 @@ unfixable = []
8181
"benchmark/*" = ["D100", "D103"]
8282
"tests/*" = ["D100", "D103", "D104"]
8383
"examples/*" = ["D100", "D103"]
84+
"tutorials/*" = ["D", "ANN"]
8485
# TODO: Remove once everything is stable and document
8586
"crazyflow/*" = ["D100", "D101", "D102", "D104", "D107"]
8687

tutorials/ppo/test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import random
3+
from pathlib import Path
34

45
import gymnasium
56
import gymnasium.wrappers.vector.jax_to_torch
@@ -43,7 +44,7 @@
4344
test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)
4445

4546
# Load checkpoint
46-
checkpoint = torch.load("ppo_checkpoint.pt")
47+
checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt")
4748

4849
# Create agent and load state
4950
agent = Agent(test_env).to(device)

0 commit comments

Comments
 (0)