Skip to content

Commit 38ab0f8

Browse files
committed
Fix gymnasium examples
1 parent 9349ff2 commit 38ab0f8

File tree

2 files changed

+16
-48
lines changed

2 files changed

+16
-48
lines changed

examples/gymnasium_env.py

+6-18
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,28 @@
11
import gymnasium
22
import numpy as np
33
from gymnasium.wrappers.vector import JaxToNumpy # , JaxToTorch
4-
from ml_collections import config_dict
54

6-
from crazyflow.control import Control
7-
from crazyflow.sim.physics import Physics
5+
import crazyflow # noqa: F401, register gymnasium envs
6+
from crazyflow.utils import enable_cache
87

98

109
def main():
11-
# set config for simulation
12-
sim_config = config_dict.ConfigDict()
13-
sim_config.device = "cpu"
14-
sim_config.physics = Physics.sys_id
15-
sim_config.control = Control.attitude
16-
sim_config.attitude_freq = 50
17-
sim_config.n_drones = 1
18-
sim_config.n_worlds = 20
19-
10+
enable_cache()
2011
SEED = 42
21-
22-
envs = gymnasium.make_vec(
23-
"DroneLanding-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
24-
)
12+
envs = gymnasium.make_vec("DroneLanding-v0", num_envs=20, freq=50, time_horizon_in_seconds=2)
2513

2614
# This wrapper makes it possible to interact with the environment using numpy arrays, if
2715
# desired. JaxToTorch is available as well.
2816
envs = JaxToNumpy(envs)
2917

3018
# dummy action for going up (in attitude control)
31-
action = np.zeros((sim_config.n_worlds * sim_config.n_drones, 4), dtype=np.float32)
19+
action = np.zeros((20, 4), dtype=np.float32)
3220
action[..., 0] = 0.4
3321

3422
obs, info = envs.reset(seed=SEED)
3523

3624
# Step through the environment
37-
for _ in range(1500):
25+
for _ in range(100):
3826
observation, reward, terminated, truncated, info = envs.step(action)
3927
envs.render()
4028

examples/gymnasium_env_trajectory.py

+10-30
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,25 @@
11
import gymnasium
22
import numpy as np
33
from gymnasium.wrappers.vector import JaxToNumpy # , JaxToTorch
4-
from ml_collections import config_dict
5-
from scipy.interpolate import splev
64

7-
from crazyflow.control import Control
85
from crazyflow.gymnasium_envs import CrazyflowRL # noqa: F401
9-
from crazyflow.sim.physics import Physics
6+
from crazyflow.utils import enable_cache
107

118

129
def main():
13-
# set config for simulation
14-
sim_config = config_dict.ConfigDict()
15-
sim_config.device = "cpu"
16-
sim_config.physics = Physics.sys_id
17-
sim_config.control = Control.attitude
18-
sim_config.attitude_freq = 50
19-
sim_config.n_drones = 1
20-
sim_config.n_worlds = 20
21-
10+
enable_cache()
2211
SEED = 42
23-
2412
# Create environment that contains a figure eight trajectory. You can parametrize the
2513
# observation space, i.e., which part of the trajectory is contained in the observation. Please
2614
# refer to the documentation of the environment for more information.
2715
envs = gymnasium.make_vec(
2816
"DroneFigureEightTrajectory-v0",
29-
n_trajectory_sample_points=10,
30-
dt_trajectory_sample_points=0.1,
17+
num_envs=20,
18+
freq=50,
19+
n_samples=10,
20+
samples_dt=0.1,
3121
trajectory_time=10.0,
32-
render_trajectory_sample=True, # useful for debug purposes
33-
time_horizon_in_seconds=10.0,
34-
num_envs=sim_config.n_worlds,
35-
**sim_config,
22+
render_samples=True,
3623
)
3724

3825
# RL wrapper to clip the actions to [-1, 1] and rescale them for use with common DRL libraries.
@@ -43,19 +30,12 @@ def main():
4330
envs = JaxToNumpy(envs)
4431

4532
# dummy action for going up (in attitude control)
46-
action = np.zeros((sim_config.n_worlds * sim_config.n_drones, 4), dtype=np.float32)
47-
action[..., 0] = 0.34
33+
action = np.zeros((20, 4), dtype=np.float32)
34+
action[..., 0] = 0.31
4835

4936
obs, info = envs.reset(seed=SEED)
50-
51-
# The trajectory is defined as a scipy spline. Its parameter can be retrieved using
52-
# `envs.unwrapped.tck`. The spline can be reconstructed using scipy's splev.
53-
spline_params = envs.unwrapped.tck
54-
tau = envs.unwrapped.tau # 1D parameters of the spline for the current timestep, in [0,1]
55-
value = splev(tau, spline_params) # noqa: F841, used for demonstration purposes
56-
5737
# Step through the environment
58-
for _ in range(1500):
38+
for _ in range(500):
5939
observation, reward, terminated, truncated, info = envs.step(action)
6040
envs.render()
6141

0 commit comments

Comments
 (0)