Skip to content

Commit 65ef2a3

Browse files
hbq1ChexDev
authored andcommitted
Create local objects in chexify tests to avoid reusing jax internal cache.
PiperOrigin-RevId: 503285098
1 parent ee8b69e commit 65ef2a3

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

chex/_src/asserts_chexify_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
271271
# Reports incorrect usage without `chexify()`.
272272
with self.assertRaisesRegex(
273273
RuntimeError, 'can only be called from functions wrapped .*chexify'):
274-
jax_transform(fn_value_assert)(*valid_args)
274+
# Create a local object to avoid reusing jax internal cache.
275+
local_fn_value_assert = make_test_fn(chex_value_assert_positive)
276+
jax_transform(local_fn_value_assert)(*valid_args)
275277

276278
# Run tests with invalid arguments.
277279
for invalid_args, label in zip(all_invalid_args, failure_labels):
@@ -289,7 +291,9 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
289291
# Reports incorrect usage without `chexify()`.
290292
with self.assertRaisesRegex(
291293
RuntimeError, 'can only be called from functions wrapped .*chexify'):
292-
jax_transform(fn_value_assert)(*invalid_args)
294+
# Create a local object to avoid reusing jax internal cache.
295+
local_fn_value_assert = make_test_fn(chex_value_assert_positive)
296+
jax_transform(local_fn_value_assert)(*invalid_args)
293297

294298
def run_test_suite_with_log_abs_fn(self, make_log_fn, jax_transform, devices,
295299
run_pure, run_in_thread):

0 commit comments

Comments
 (0)