diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py index 13bf10fc9..ebac83a67 100644 --- a/t5x/decoding_test.py +++ b/t5x/decoding_test.py @@ -21,7 +21,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax.experimental import host_callback as hcb +from jax.experimental import io_callback import jax.numpy as jnp import numpy as np from t5x import decoding @@ -156,12 +156,10 @@ def callback_fn(current_index_and_sequences): sequences[i, current_index[i] + 1] = EOS_ID return sequences - sequences = hcb.call( + sequences = io_callback( callback_fn, + jax.ShapeDtypeStruct(state.sequences.shape, state.sequences.dtype), (state.cur_index, state.sequences), - result_shape=jax.ShapeDtypeStruct( - state.sequences.shape, state.sequences.dtype - ), ) return state.replace(sequences=sequences)