You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -48,7 +41,7 @@ class CrazyflowBaseEnv(VectorEnv):
48
41
def__init__(
49
42
self,
50
43
*,
51
-
jax_random_key, # required for jax random number generator
44
+
jax_random_key: int, # required for jax random number generator
52
45
num_envs: int=1, # required for VectorEnv
53
46
max_episode_steps: int=1000,
54
47
return_datatype: Literal["numpy", "jax"] ="jax",
@@ -57,9 +50,13 @@ def __init__(
57
50
"""Summary: Initializes the CrazyflowEnv.
58
51
59
52
Args:
60
-
max_episode_steps (int): The maximum number of steps per episode.
61
-
return_datatype (Literal["numpy", "jax"]): The data type for returned arrays, either "numpy" or "jax". If specified as "numpy", the returned arrays will be numpy arrays on the CPU. If specified as "jax", the returned arrays will be jax arrays on the "device" specifiedf or the simulation.
62
-
**kwargs: Takes arguments that are passed to the Crazyfly simulation .
53
+
jax_random_key: The random key for the jax random number generator.
54
+
num_envs: The number of environments to run in parallel.
55
+
max_episode_steps: The maximum number of steps per episode.
56
+
return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
57
+
the returned arrays will be numpy arrays on the CPU. If "jax", the returned arrays
58
+
will be jax arrays on the "device" specified for the simulation.
59
+
**kwargs: Takes arguments that are passed to the Crazyfly simulation.
63
60
"""
64
61
assertnum_envs==kwargs["n_worlds"], "num_envs must be equal to n_worlds"
0 commit comments