File tree Expand file tree Collapse file tree 1 file changed +16
-6
lines changed Expand file tree Collapse file tree 1 file changed +16
-6
lines changed Original file line number Diff line number Diff line change @@ -176,9 +176,14 @@ def tree_max(tree: Any) -> chex.Numeric:
176176 Returns:
177177 a scalar value.
178178 """
179- maxes = jax .tree .map (jnp .max , tree )
180- # initializer=-jnp.inf should work but pytype wants a jax.Array.
181- return jax .tree .reduce (jnp .maximum , maxes , initializer = jnp .array (- jnp .inf ))
179+ identity = - float ("inf" )
180+ def f (array ):
181+ if jnp .size (array ) == 0 :
182+ return identity
183+ else :
184+ return jnp .max (array )
185+ maxes = jax .tree .map (f , tree )
186+ return jax .tree .reduce (jnp .maximum , maxes , initializer = identity )
182187
183188
184189def tree_min (tree : Any ) -> chex .Numeric :
@@ -190,9 +195,14 @@ def tree_min(tree: Any) -> chex.Numeric:
190195 Returns:
191196 a scalar value.
192197 """
193- mins = jax .tree .map (jnp .min , tree )
194- # initializer=jnp.inf should work but pytype wants a jax.Array.
195- return jax .tree .reduce (jnp .minimum , mins , initializer = jnp .array (jnp .inf ))
198+ identity = float ("inf" )
199+ def f (array ):
200+ if jnp .size (array ) == 0 :
201+ return identity
202+ else :
203+ return jnp .min (array )
204+ mins = jax .tree .map (f , tree )
205+ return jax .tree .reduce (jnp .minimum , mins , initializer = identity )
196206
197207
198208def tree_size (tree : Any ) -> int :
You can’t perform that action at this time.
0 commit comments