@@ -271,7 +271,9 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
271271 # Reports incorrect usage without `chexify()`.
272272 with self .assertRaisesRegex (
273273 RuntimeError , 'can only be called from functions wrapped .*chexify' ):
274- jax_transform (fn_value_assert )(* valid_args )
274+ # Create a local object to avoid reusing jax internal cache.
275+ local_fn_value_assert = make_test_fn (chex_value_assert_positive )
276+ jax_transform (local_fn_value_assert )(* valid_args )
275277
276278 # Run tests with invalid arguments.
277279 for invalid_args , label in zip (all_invalid_args , failure_labels ):
@@ -289,7 +291,9 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
289291 # Reports incorrect usage without `chexify()`.
290292 with self .assertRaisesRegex (
291293 RuntimeError , 'can only be called from functions wrapped .*chexify' ):
292- jax_transform (fn_value_assert )(* invalid_args )
294+ # Create a local object to avoid reusing jax internal cache.
295+ local_fn_value_assert = make_test_fn (chex_value_assert_positive )
296+ jax_transform (local_fn_value_assert )(* invalid_args )
293297
294298 def run_test_suite_with_log_abs_fn (self , make_log_fn , jax_transform , devices ,
295299 run_pure , run_in_thread ):
0 commit comments