@@ -47,7 +47,6 @@ class CrazyflowBaseEnv(VectorEnv):
     def __init__(
         self,
         *,
-        jax_random_key: int,  # required for jax random number generator
         num_envs: int = 1,  # required for VectorEnv
         max_episode_steps: int = 1000,
         return_datatype: Literal["numpy", "jax"] = "jax",
@@ -56,7 +55,6 @@ def __init__(
         """Initializes the CrazyflowEnv.

         Args:
-            jax_random_key: The random key for the jax random number generator.
             num_envs: The number of environments to run in parallel.
             max_episode_steps: The maximum number of steps per episode.
             return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
@@ -66,12 +64,14 @@ def __init__(
         """
         assert num_envs == kwargs["n_worlds"], "num_envs must be equal to n_worlds"

-        self.jax_key = jax.random.key(jax_random_key)
+        # Set a random initial JAX seed; for deterministic seeding, use the reset function
+        jax_seed = int(self.np_random.random() * 2**32)
+        self.jax_key = jax.random.key(jax_seed)

         self.num_envs = num_envs
         self.return_datatype = return_datatype
         self.device = jax.devices(kwargs["device"])[0]
-        self.max_episode_steps = jnp.array(max_episode_steps, dtype=jnp.int32, device=self.device)
+        self.max_episode_steps = max_episode_steps

         self.sim = Sim(**kwargs)

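Note on the change above: the constructor no longer takes a `jax_random_key`; the initial JAX key is derived from Gymnasium's `self.np_random`, and explicit seeding moves to the reset function. A minimal, self-contained sketch of the derivation pattern (the generator here stands in for the env's `np_random`; it is not taken from the PR):

```python
import jax
import numpy as np

np_random = np.random.default_rng()  # stand-in for the env's Gymnasium RNG

# Derive a 32-bit seed from the numpy generator, then build the JAX key from it.
jax_seed = int(np_random.random() * 2**32)
jax_key = jax.random.key(jax_seed)

# Typical use: split before sampling so the stored key stays fresh.
jax_key, subkey = jax.random.split(jax_key)
sample = jax.random.uniform(subkey, (3,))
```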
@@ -83,7 +83,7 @@ def __init__(
                 "Simulation frequency should be a multiple of the control frequency. We can handle the other case, but we highly recommend changing the simulation frequency to a multiple of the control frequency."
             )

-        self.n_substeps = jnp.array(self.sim.freq // self.sim.control_freq)
+        self.n_substeps = self.sim.freq // self.sim.control_freq

         self.prev_done = jnp.zeros((self.sim.n_worlds), dtype=jnp.bool_, device=self.device)

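For intuition on the substep count: the simulation advances `freq // control_freq` physics steps per control step. A worked example with assumed frequencies (500 Hz physics, 100 Hz control):

```python
sim_freq = 500      # physics steps per second (assumed example value)
control_freq = 100  # control decisions per second (assumed example value)

assert sim_freq % control_freq == 0, "sim freq should be a multiple of control freq"
n_substeps = sim_freq // control_freq  # -> 5 physics steps per env step
```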
@@ -111,44 +111,40 @@ def __init__(

     def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
         assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
-        action = jnp.array(action, device=self.device).reshape(
-            (self.sim.n_worlds, self.sim.n_drones, -1)
-        )
-
+        action = self._sanitize_action(action, self.sim.n_worlds, self.sim.n_drones, self.device)
         action = self._rescale_action(action, self.sim.control)

-        if self.sim.control == Control.state:
-            raise NotImplementedError(
-                "Possibly you want to control state differences instead of absolute states"
-            )
-            self.sim.state_control(action)
-        elif self.sim.control == Control.attitude:
-            self.sim.attitude_control(action)
-        elif self.sim.control == Control.thrust:
-            self.sim.thrust_control(action)
-        else:
-            raise ValueError(f"Invalid control type {self.sim.control}")
+        match self.sim.control:
+            case Control.state:
+                raise NotImplementedError(
+                    "Possibly you want to control state differences instead of absolute states"
+                )
+            case Control.attitude:
+                self.sim.attitude_control(action)
+            case Control.thrust:
+                self.sim.thrust_control(action)
+            case _:
+                raise ValueError(f"Invalid control type {self.sim.control}")

         for _ in range(self.n_substeps):
             self.sim.step()
-
         # Reset all environments which terminated or were truncated in the last step
         if jnp.any(self.prev_done):
             self.reset(mask=self.prev_done)

-        reward = self.reward
         terminated = self.terminated
         truncated = self.truncated
+        self.prev_done = self._done(terminated, truncated)

-        self.prev_done = jnp.logical_or(terminated, truncated)
+        convert = self.return_datatype == "numpy"
+        terminated = maybe_to_numpy(terminated, convert)
+        truncated = maybe_to_numpy(truncated, convert)
+        return self._obs(), self.reward, terminated, truncated, {}

-        return (
-            self._get_obs(),
-            reward,
-            maybe_to_numpy(terminated, self.return_datatype == "numpy"),
-            maybe_to_numpy(truncated, self.return_datatype == "numpy"),
-            {},
-        )
+    @staticmethod
+    @partial(jax.jit, static_argnames=["n_worlds", "n_drones", "device"])
+    def _sanitize_action(action: Array, n_worlds: int, n_drones: int, device: str) -> Array:
+        return jnp.array(action, device=device).reshape((n_worlds, n_drones, -1))

     @staticmethod
     @partial(jax.jit, static_argnames=["control_type"])
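The jitted `_sanitize_action` marks `n_worlds`, `n_drones`, and `device` as static because reshape targets must be compile-time constants under `jax.jit`; each distinct value triggers one retrace. A standalone sketch of the same pattern (names are illustrative, not from the PR):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["n_worlds", "n_drones"])
def reshape_batch(action: jax.Array, n_worlds: int, n_drones: int) -> jax.Array:
    # The static ints are baked into the traced computation, so the
    # output shape is known at compile time.
    return jnp.asarray(action).reshape((n_worlds, n_drones, -1))

flat = jnp.ones(24)
batched = reshape_batch(flat, n_worlds=2, n_drones=3)  # shape (2, 3, 4)
```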
@@ -167,14 +163,19 @@ def _rescale_action(action: Array, control_type: str) -> Array:
             raise NotImplementedError(
                 f"Rescaling not implemented for control type '{control_type}'"
             )
-
         return action * params.scale_factor + params.mean

+    @staticmethod
+    @jax.jit
+    def _done(terminated: Array, truncated: Array) -> Array:
+        return jnp.logical_or(terminated, truncated)
+
     def reset_all(
         self, *, seed: int | None = None, options: dict | None = None
     ) -> tuple[dict[str, Array], dict]:
         super().reset(seed=seed)
-
+        if seed is not None:
+            self.jax_key = jax.random.key(seed)
         # Resets ALL (!) environments
         if options is None:
             options = {}
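With seeding moved into `reset_all`, a reproducible setup would look roughly like this (constructor kwargs are assumptions about the `Sim` configuration, not taken from the PR):

```python
# Hypothetical usage of the revised seeding path.
env = CrazyflowBaseEnv(num_envs=4, n_worlds=4, n_drones=1, device="cpu")

# Seeds the Gymnasium np_random generator and, via the new branch above,
# replaces self.jax_key, so JAX-side randomness is reproducible too.
obs, info = env.reset_all(seed=42)
obs2, info2 = env.reset_all(seed=42)  # same seed -> same initial observations
```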
@@ -183,7 +184,7 @@ def reset_all(

         self.prev_done = jnp.zeros((self.sim.n_worlds), dtype=jnp.bool_)

-        return self._get_obs(), {}
+        return self._obs(), {}

     def reset(self, mask: Array) -> None:
         self.sim.reset(mask=mask)
@@ -241,26 +242,21 @@ def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
         return jnp.where(dones, False, terminated)

     @staticmethod
-    @jax.jit
-    def _truncated(
-        dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
-    ) -> Array:
+    @partial(jax.jit, static_argnames=["max_episode_steps", "n_substeps"])
+    def _truncated(dones: Array, steps: Array, max_episode_steps: int, n_substeps: int) -> Array:
         truncated = steps / n_substeps >= max_episode_steps
         return jnp.where(dones, False, truncated)

     def render(self):
         self.sim.render()

-    def _get_obs(self) -> dict[str, Array]:
-        obs = {
-            state: maybe_to_numpy(
-                getattr(self.sim.states, state)[..., 2]
-                if state == "pos"
-                else getattr(self.sim.states, state),
-                self.return_datatype == "numpy",
-            )
-            for state in self.states_to_include_in_obs
-        }
+    def _obs(self) -> dict[str, Array]:
+        convert = self.return_datatype == "numpy"
+        fields = self.states_to_include_in_obs
+        states = [maybe_to_numpy(getattr(self.sim.states, field), convert) for field in fields]
+        obs = {k: v for k, v in zip(fields, states)}
+        if "pos" in obs:
+            obs["pos"] = obs["pos"][..., 2]
         return obs

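The truncation test divides the raw simulator step counter by the substep count, so episode length is measured in control steps. A quick check with the example numbers from above (5 substeps, 1000 control steps per episode):

```python
import jax.numpy as jnp

n_substeps = 5            # physics steps per control step (assumed)
max_episode_steps = 1000  # control steps per episode (assumed)

steps = jnp.array([4995, 5000, 5010])  # raw simulator step counters
truncated = steps / n_substeps >= max_episode_steps
# -> [False, True, True]: 5000 physics steps equal 1000 control steps
```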
@@ -276,8 +272,7 @@ def __init__(self, **kwargs: dict):
             -jnp.inf, jnp.inf, shape=(self._obs_size,), dtype=jnp.float32
         )
         self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)
-
-        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
+        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32, device=self.device)

     @property
     def reward(self) -> Array:
@@ -303,8 +298,8 @@ def reset(self, mask: Array) -> None:
         )
         self.goal = self.goal.at[mask].set(new_goals[mask])

-    def _get_obs(self) -> dict[str, Array]:
-        obs = super()._get_obs()
+    def _obs(self) -> dict[str, Array]:
+        obs = super()._obs()
         obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
         return obs

@@ -348,7 +343,7 @@ def reset(self, mask: Array) -> None:
         )
         self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])

-    def _get_obs(self) -> dict[str, Array]:
-        obs = super()._get_obs()
+    def _obs(self) -> dict[str, Array]:
+        obs = super()._obs()
         obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
         return obs
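Putting the pieces together, a minimal interaction loop against the revised API might look as follows; the constructor kwargs and concrete task behavior are assumptions, not taken from this diff:

```python
import numpy as np

# Hypothetical setup mirroring the constructor shown above.
env = CrazyflowBaseEnv(
    num_envs=2, n_worlds=2, n_drones=1, device="cpu", return_datatype="numpy"
)
obs, info = env.reset_all(seed=0)

for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    # With return_datatype="numpy", terminated/truncated arrive as numpy
    # arrays; done envs are auto-reset on the next step() via prev_done.
    assert isinstance(terminated, np.ndarray)
```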