From a42d50b619f3d4503b01b96f0c6002f2338109a5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 May 2024 08:46:16 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634412089 --- t5x/adafactor_test.py | 10 +++---- t5x/checkpoint_importer.py | 2 +- t5x/decoding.py | 6 ++-- t5x/decoding_test.py | 4 +-- t5x/infer.py | 8 ++--- t5x/metrics.py | 4 +-- t5x/optimizers.py | 20 ++++++------- t5x/optimizers_test.py | 16 +++++----- t5x/partitioning.py | 5 ++-- t5x/partitioning_test.py | 8 ++--- t5x/precompile.py | 6 ++-- t5x/test_utils.py | 12 ++++---- t5x/train.py | 6 ++-- t5x/train_state.py | 2 +- t5x/train_state_test.py | 60 +++++++++++++++++++------------------- t5x/trainer_test.py | 32 ++++++++++---------- t5x/utils_test.py | 12 ++++---- 17 files changed, 107 insertions(+), 106 deletions(-) diff --git a/t5x/adafactor_test.py b/t5x/adafactor_test.py index df8f07e38..59cada411 100644 --- a/t5x/adafactor_test.py +++ b/t5x/adafactor_test.py @@ -58,13 +58,13 @@ def check_eq(xs, ys, atol=None, rtol=None): ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys) assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}" assert jax.tree_util.tree_all( - jax.tree_map( + jax.tree.map( lambda x, y: np.array(x).shape == np.array(y).shape, xs_leaves, ys_leaves, ) ), "Leaves' shapes don't match." - assert jax.tree_map( + assert jax.tree.map( functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), xs_leaves, ys_leaves, @@ -77,11 +77,11 @@ def flattened_state_dict(x): def tree_shape(x): - return jax.tree_map(jnp.shape, x) + return jax.tree.map(jnp.shape, x) def tree_equals(x, y): - return jax.tree_util.tree_all(jax.tree_map(operator.eq, x, y)) + return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y)) def _get_multi_adafactor( @@ -483,7 +483,7 @@ def test_standard_factor_rules(): # create fake model parameters k = jax.random.PRNGKey(0) - params = jax.tree_map( + params = jax.tree.map( lambda shape: jax.random.uniform(k, shape), MODEL_SHAPE, is_leaf=lambda x: isinstance(x, list), diff --git a/t5x/checkpoint_importer.py b/t5x/checkpoint_importer.py index 8baa09f99..312ecf5df 100644 --- a/t5x/checkpoint_importer.py +++ b/t5x/checkpoint_importer.py @@ -545,7 +545,7 @@ def restore_from_t5_checkpoint( t5_data = _maybe_correct_relpos_bias(t5_data) state_dict = _update_state_dict(state_dict, t5_data, strict=strict) if not lazy_parameters: - state_dict = jax.tree_map( + state_dict = jax.tree.map( lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict ) return state_dict diff --git a/t5x/decoding.py b/t5x/decoding.py index 622ef40e4..fc5eaec96 100644 --- a/t5x/decoding.py +++ b/t5x/decoding.py @@ -762,7 +762,7 @@ def cache_map(fn, cache, apply_to_index: bool = False): exclusion_list = ['cached_bias', 'position_embedder_index'] keyvals = {k: v for k, v in keyvals.items() if k[-1] not in exclusion_list} - keyvals = jax.tree_map(fn, keyvals) + keyvals = jax.tree.map(fn, keyvals) flat_cache.update(keyvals) new_cache = traverse_util.unflatten_dict(flat_cache) if frozen: @@ -901,7 +901,7 @@ def gather_beams( def gather_fn(x): return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype) - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) else: # True gather via fancy indexing. 
batch_indices = jnp.reshape( @@ -912,7 +912,7 @@ def gather_fn(x): def gather_fn(x): return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def top_k_two_stage(x, k): diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py index 363053de2..13bf10fc9 100644 --- a/t5x/decoding_test.py +++ b/t5x/decoding_test.py @@ -891,7 +891,7 @@ def test_cache_map(self): }, } - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, decoding.cache_map(fn, cache), gold_cache ) @@ -956,7 +956,7 @@ def test_cache_map_with_index(self): }, } - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, decoding.cache_map(fn, cache, apply_to_index=True), gold_cache, diff --git a/t5x/infer.py b/t5x/infer.py index 98c1e5ff0..d993b2d70 100644 --- a/t5x/infer.py +++ b/t5x/infer.py @@ -273,7 +273,7 @@ def _json_compat(value): with gfile.GFile(path, 'w') as f: for i, inp in task_ds.enumerate().as_numpy_iterator(): predictions = all_predictions[i] - aux_values = jax.tree_map( + aux_values = jax.tree.map( f=lambda v, i=i: v[i], tree=all_aux_values, is_leaf=lambda v: isinstance(v, (np.ndarray, list)), @@ -311,7 +311,7 @@ def _json_compat(value): elif mode == 'score': json_dict['score'] = _json_compat(predictions) if aux_values: - json_dict['aux'] = jax.tree_map(_json_compat, aux_values) + json_dict['aux'] = jax.tree.map(_json_compat, aux_values) elif mode == 'predict_with_aux': assert vocabulary is not None json_dict['prediction'] = _json_compat( @@ -322,7 +322,7 @@ def _json_compat(value): # Truncate padding tokens. pred = pred[: pred.index(0)] if 0 in pred else pred json_dict['prediction_tokens'] = pred - json_dict['aux'] = jax.tree_map(_json_compat, aux_values) + json_dict['aux'] = jax.tree.map(_json_compat, aux_values) else: raise ValueError(f'Invalid mode: {mode}') json_str = json.dumps(json_dict, cls=json_encoder_cls) @@ -353,7 +353,7 @@ def _extract_tokens_and_aux_values(inference_fn_outputs) -> Inferences: permutation = np.argsort(indices) permute = lambda v: [v[permutation[i]] for i in range(len(permutation))] tokens = permute(tokens) - all_aux_values = jax.tree_map( + all_aux_values = jax.tree.map( f=permute, tree=all_aux_values, is_leaf=lambda v: isinstance(v, (np.ndarray, list)), diff --git a/t5x/metrics.py b/t5x/metrics.py index 9f21fa031..2ae9e4fec 100644 --- a/t5x/metrics.py +++ b/t5x/metrics.py @@ -298,7 +298,7 @@ def fn(o): else: return o - return jax.tree_map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time)) + return jax.tree.map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time)) def set_step_metrics_num_steps(metrics, num_steps): @@ -310,4 +310,4 @@ def fn(o): else: return o - return jax.tree_map(fn, metrics, is_leaf=is_metric_obj) + return jax.tree.map(fn, metrics, is_leaf=is_metric_obj) diff --git a/t5x/optimizers.py b/t5x/optimizers.py index 01ffd4552..2d663c8db 100644 --- a/t5x/optimizers.py +++ b/t5x/optimizers.py @@ -275,7 +275,7 @@ class OptaxStatePartitionRules: optax.InjectHyperparamsState: ( lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray count=None, - hyperparams=jax.tree_map(lambda x: None, state.hyperparams), + hyperparams=jax.tree.map(lambda x: None, state.hyperparams), inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( state.inner_state, params_axes ), @@ -438,7 +438,7 @@ def derive_logical_axes(self, optimizer, param_logical_axes): An `optimizers.Optimizer` instance, with all the leafs replaced by t5x PartitionSpec or None (no 
partition). """ - optimizer_logical_axes = jax.tree_map( + optimizer_logical_axes = jax.tree.map( lambda x: None, optimizer.state_dict() ) optimizer_logical_axes['target'] = param_logical_axes @@ -698,7 +698,7 @@ def __init__( self.sub_optimizers = sub_optimizers def init_state(self, params): - param_states = jax.tree_map(lambda x: _Marker(), params) + param_states = jax.tree.map(lambda x: _Marker(), params) overlap = False for idx, traversal in enumerate(self.traversals): for match in traversal.iterate(param_states): @@ -707,10 +707,10 @@ def init_state(self, params): if overlap: raise ValueError( 'Multiple optimizers match the same leaves : ' - + str(jax.tree_map(lambda match: match._indices, param_states)) # pylint: disable=protected-access + + str(jax.tree.map(lambda match: match._indices, param_states)) # pylint: disable=protected-access ) - param_states = jax.tree_map(lambda x: _Marker(), params) + param_states = jax.tree.map(lambda x: _Marker(), params) for focus, opt_def in zip(self.traversals, self.sub_optimizers): ps = _subtree_from_traversal(focus, params) ss = opt_def.init_state(ps) @@ -718,7 +718,7 @@ def init_state(self, params): focus, param_states, ss.param_states ) # Update state to None when param is not optimized by any sub optimizer. - param_states = jax.tree_map( + param_states = jax.tree.map( lambda x: (None if isinstance(x, _Marker) else x), param_states ) return OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states) @@ -726,7 +726,7 @@ def init_state(self, params): def apply_gradient(self, hyper_params, params, state, grads): new_params = params it = zip(self.traversals, self.sub_optimizers, hyper_params) - new_param_states = jax.tree_map(lambda x: _Marker(), params) + new_param_states = jax.tree.map(lambda x: _Marker(), params) for focus, opt_def, hp in it: ps = _subtree_from_traversal(focus, params) gs = _subtree_from_traversal(focus, grads) @@ -738,7 +738,7 @@ def apply_gradient(self, hyper_params, params, state, grads): focus, new_param_states, new_ss.param_states ) # Update state to None when param is not optimized by any sub optimizer. - new_param_states = jax.tree_map( + new_param_states = jax.tree.map( lambda x: (None if isinstance(x, _Marker) else x), new_param_states ) return new_params, OptimizerState(state.step + 1, new_param_states) @@ -772,7 +772,7 @@ def set_param_axes(self, param_logical_axes): def derive_logical_axes(self, optimizer, param_logical_axes): """Derives optimizer logical partitioning from model logical partitions.""" - param_states = jax.tree_map( + param_states = jax.tree.map( lambda x: _Marker(), optimizer.state.param_states ) for focus, opt_def in zip(self.traversals, self.sub_optimizers): @@ -786,7 +786,7 @@ def derive_logical_axes(self, optimizer, param_logical_axes): focus, param_states, new_opt.state.param_states ) # Update axes to None when param is not optimized by any sub optimizer. - param_states = jax.tree_map( + param_states = jax.tree.map( lambda x: (None if isinstance(x, _Marker) else x), param_states ) return Optimizer( diff --git a/t5x/optimizers_test.py b/t5x/optimizers_test.py index f9d130eb4..e02e40d2f 100644 --- a/t5x/optimizers_test.py +++ b/t5x/optimizers_test.py @@ -54,13 +54,13 @@ def check_eq(xs, ys, atol=None, rtol=None): ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys) assert xs_tree == ys_tree, f"Tree shapes don't match. 
\n{xs_tree}\n{ys_tree}" assert jax.tree_util.tree_all( - jax.tree_map( + jax.tree.map( lambda x, y: np.array(x).shape == np.array(y).shape, xs_leaves, ys_leaves, ) ), "Leaves' shapes don't match." - assert jax.tree_map( + assert jax.tree.map( functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), xs_leaves, ys_leaves, @@ -73,11 +73,11 @@ def flattened_state_dict(x): def tree_shape(x): - return jax.tree_map(jnp.shape, x) + return jax.tree.map(jnp.shape, x) def tree_equals(x, y): - return jax.tree_util.tree_all(jax.tree_map(operator.eq, x, y)) + return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y)) def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__): @@ -160,7 +160,7 @@ def get_params(cls): @classmethod def get_params_shapes(cls): - return jax.tree_map(jnp.shape, cls.get_params()) + return jax.tree.map(jnp.shape, cls.get_params()) @classmethod def get_param_logical_axes(cls): @@ -249,7 +249,7 @@ def test_adamw_state_serialization(self): state_dict = optimizer.state_dict() chex.assert_trees_all_equal( - frozen_dict.FrozenDict(jax.tree_map(jnp.shape, state_dict)), + frozen_dict.FrozenDict(jax.tree.map(jnp.shape, state_dict)), frozen_dict.FrozenDict({ 'target': self.get_params_shapes(), 'state': { @@ -292,8 +292,8 @@ def run_train_loop(self, optimizer_def): learning_rate_fn = utils.create_learning_rate_scheduler() - input_shapes = jax.tree_map(jnp.shape, first_batch) - input_types = jax.tree_map(lambda x: jnp.dtype(x.dtype), first_batch) + input_shapes = jax.tree.map(jnp.shape, first_batch) + input_types = jax.tree.map(lambda x: jnp.dtype(x.dtype), first_batch) partitioner = partitioning.PjitPartitioner( num_partitions=2, diff --git a/t5x/partitioning.py b/t5x/partitioning.py index 6537294d1..2ea7b9455 100644 --- a/t5x/partitioning.py +++ b/t5x/partitioning.py @@ -873,7 +873,8 @@ def get_logical_axes(self, train_state: TrainState) -> TrainState: """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" # By default, return None for the logical axes. 
return train_state.restore_state( - jax.tree_map(lambda x: None, train_state.state_dict())) + jax.tree.map(lambda x: None, train_state.state_dict()) + ) def get_mesh_axes(self, train_state: TrainState) -> TrainState: """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" @@ -1136,7 +1137,7 @@ def _logical_to_mesh_axes(param_name, logical_axes): # arr_tree is a PyTree of jax.Array or np.ndarray and # pspecs is PyTree[jax.sharding.PartitionSpec] def host_local_array_to_global_array(arr_tree, mesh: jax.sharding.Mesh, pspecs): - pspecs = jax.tree_map( + pspecs = jax.tree.map( lambda x: PartitionSpec() if x is None else x, pspecs, is_leaf=lambda x: x is None, diff --git a/t5x/partitioning_test.py b/t5x/partitioning_test.py index 66e02febf..199c27db6 100644 --- a/t5x/partitioning_test.py +++ b/t5x/partitioning_test.py @@ -252,7 +252,7 @@ def test_get_mesh_axes(self): adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None), adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None), ) - jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec) + jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=True) expected_axes_spec = self.get_expected_axes_spec( @@ -263,7 +263,7 @@ def test_get_mesh_axes(self): m=p1_spec, v=None, v_col=None, v_row=None ), ) - jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec) + jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=True) expected_axes_spec = self.get_expected_axes_spec( @@ -274,7 +274,7 @@ def test_get_mesh_axes(self): m=p1_spec, v=p1_spec, v_col=None, v_row=None ), ) - jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec) + jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=False) expected_axes_spec = self.get_expected_axes_spec( @@ -285,7 +285,7 @@ def test_get_mesh_axes(self): m=None, v=p1_spec, v_col=None, v_row=None ), ) - jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec) + jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) @parameterized.product(activation_dims=(1, 2), param_dims=(1, 2)) def test_standard_logical_axis_rules(self, activation_dims, param_dims): diff --git a/t5x/precompile.py b/t5x/precompile.py index c01d5b2ee..eb7664b33 100644 --- a/t5x/precompile.py +++ b/t5x/precompile.py @@ -79,11 +79,11 @@ def precompile( ) # Need to use full batch size. - input_shapes = jax.tree_map( + input_shapes = jax.tree.map( lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec ) - input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) - dummy_batch = jax.tree_map( + input_types = jax.tree.map(lambda x: x.dtype, train_iter.element_spec) + dummy_batch = jax.tree.map( lambda x: np.ones(x.shape, x.dtype), train_iter.element_spec ) diff --git a/t5x/test_utils.py b/t5x/test_utils.py index 51a3b8bdd..960f240bd 100644 --- a/t5x/test_utils.py +++ b/t5x/test_utils.py @@ -80,7 +80,7 @@ def coords_to_idx(coords: Tuple[int, ...], bounds: Tuple[int, ...]) -> int: # Calculate stride multipliers. 
strides = tuple(itertools.accumulate((1,) + bounds[:-1], operator.mul)) # Sum linear index from strides and coords - return sum(jax.tree_map(lambda x, y: x * y, coords, strides)) + return sum(jax.tree.map(lambda x, y: x * y, coords, strides)) def make_devices( @@ -94,11 +94,11 @@ def make_devices( """Create mock TPU devices.""" devices = [] device_bounds = (nx, ny, nz, nc) - hnx, hny, hnz, hnc = jax.tree_map( + hnx, hny, hnz, hnc = jax.tree.map( lambda a, b: a // b, device_bounds, host_layout ) for x, y, z, c in itertools.product(*map(range, device_bounds)): - hx, hy, hz, hc = jax.tree_map( + hx, hy, hz, hc = jax.tree.map( lambda a, b: a // b, (x, y, z, c), host_layout ) # TODO(levskaya, jekbradbury): verify this id/host ordering on TPU v4 @@ -165,7 +165,7 @@ def make_train_state( global_input_shape, step=step, dtype=dtype ) - return jax.tree_map( + return jax.tree.map( functools.partial( create_sharded_array, global_shape=global_input_shape, @@ -332,7 +332,7 @@ def assert_equal(a, b): def assert_same(tree_a, tree_b): """Asserts that both trees are the same.""" tree_a, tree_b = jax.device_get((tree_a, tree_b)) - jax.tree_map(assert_equal, tree_a, tree_b) + jax.tree.map(assert_equal, tree_a, tree_b) def get_train_state_from_variables( @@ -384,7 +384,7 @@ def move_params_to_devices(self, train_state, train_state_axes): return train_state def get_mesh_axes(self, train_state): - mesh_axes = jax.tree_map(lambda _: self._mesh_axes, train_state) + mesh_axes = jax.tree.map(lambda _: self._mesh_axes, train_state) return mesh_axes.replace_step(None) def _local_chunker(self): diff --git a/t5x/train.py b/t5x/train.py index 873d66a0c..8bbe866a8 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -276,11 +276,11 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): partitioner=partitioner, data_layout=data_layout, ) - input_shapes = jax.tree_map( + input_shapes = jax.tree.map( lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec, ) - input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) + input_types = jax.tree.map(lambda x: x.dtype, train_iter.element_spec) if train_eval_dataset_cfg: _verify_matching_vocabs(train_eval_dataset_cfg) @@ -631,7 +631,7 @@ def _as_gda(spec): ) # Construct dummy batch for compiling the model. - dummy_batch = jax.tree_map(_as_gda, train_iter.element_spec) + dummy_batch = jax.tree.map(_as_gda, train_iter.element_spec) if not isinstance(dummy_batch, Mapping): raise ValueError( 'Training loop expects batches to have type ' diff --git a/t5x/train_state.py b/t5x/train_state.py index 04a86288e..4ede9ea18 100644 --- a/t5x/train_state.py +++ b/t5x/train_state.py @@ -308,7 +308,7 @@ def restore_state(self, state_dict: Mapping[str, Any]) -> 'InferenceState': def as_logical_axes(self) -> 'InferenceState': # Set step to None so that when the logical axes are processed by the # flax.partitioning.logical_to_mesh_axes function, it will be skipped - # because jax.tree_map will short circut and never call the function on the + # because jax.tree.map will short circut and never call the function on the # step. 
flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT return InferenceState( diff --git a/t5x/train_state_test.py b/t5x/train_state_test.py index 768a0112f..1e37d4674 100644 --- a/t5x/train_state_test.py +++ b/t5x/train_state_test.py @@ -47,8 +47,8 @@ def test_init(self): self.assertEqual( state.state_dict()['flax_mutables'], flax.core.unfreeze(flax_mutables) ) - jax.tree_map(np.testing.assert_array_equal, params, state.params) - jax.tree_map( + jax.tree.map(np.testing.assert_array_equal, params, state.params) + jax.tree.map( np.testing.assert_array_equal, optimizer.state.param_states, state.param_states, @@ -66,12 +66,12 @@ def test_create(self): self.assertEqual(state.step, 0) self.assertIsInstance(state._optimizer, optimizers.Optimizer) self.assertEqual(state._optimizer.optimizer_def, optimizer_def) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.flax_mutables, flax.core.freeze({'mutables': np.ones(3)}), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.params, model_variables['params'] ) self.assertIsNone(state.params_axes) @@ -107,10 +107,10 @@ def test_create_with_params_axes(self): }, ) self.assertEqual(state.flax_mutables, flax.core.freeze({})) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, model_variables['params'], state.params ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, model_variables['params_axes'], state.params_axes, @@ -160,15 +160,15 @@ def test_create_with_flax_mutables_axes(self): state.flax_mutables, flax.core.freeze({'grads': model_variables['grads']}), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, model_variables['params'], state.params ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, model_variables['params_axes'], state.params_axes, ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, model_variables['grads_axes'], state.flax_mutables_axes['grads'], @@ -212,10 +212,10 @@ def test_replace_params(self): new_params = {'test': np.zeros(10)} new_state = state.replace_params(new_params) - jax.tree_map(np.testing.assert_array_equal, new_params, new_state.params) + jax.tree.map(np.testing.assert_array_equal, new_params, new_state.params) expected_state_dict = state.state_dict() expected_state_dict['target'] = new_params - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, expected_state_dict, new_state.state_dict(), @@ -273,7 +273,7 @@ def test_as_logical_axes(self): ) axes_state = state.as_logical_axes() self.assertIsNone(axes_state.params_axes) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, axes_state.params, flax.core.freeze({ @@ -316,7 +316,7 @@ def test_as_logical_axes_with_flax_mutables(self): ) axes_state = state.as_logical_axes() self.assertIsNone(axes_state.params_axes) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, axes_state.flax_mutables, flax.core.freeze({ @@ -357,7 +357,7 @@ def test_as_logical_axes_with_flax_mutables_without_mutables_axes(self): axes_state = state.as_logical_axes() self.assertIsNone(axes_state.params_axes) self.assertIsNone(axes_state.flax_mutables_axes) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, axes_state.flax_mutables, flax.core.freeze({}), @@ -381,7 +381,7 @@ def test_to_state_dict(self): state = train_state_lib.FlaxOptimTrainState.create( optimizer_def, model_variables ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.state_dict(), { @@ -438,17 +438,17 @@ def test_restore_state(self): self.assertEqual(restored.step, 1) 
self.assertIsInstance(restored._optimizer, optimizers.Optimizer) self.assertEqual(restored._optimizer.optimizer_def, optimizer_def) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.flax_mutables, flax.core.freeze({'mutables': np.zeros(3)}), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.params, flax.core.freeze({'kernel': np.ones((2, 4))}), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.param_states, flax.core.freeze({ @@ -457,7 +457,7 @@ def test_restore_state(self): ) }), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.params_axes, model_variables['params_axes'], @@ -476,7 +476,7 @@ def test_init(self): ) self.assertEqual(state.step, 1) self.assertEqual(state.flax_mutables, flax.core.unfreeze(flax_mutables)) - jax.tree_map(np.testing.assert_array_equal, params, state.params) + jax.tree.map(np.testing.assert_array_equal, params, state.params) self.assertIsNone(state.params_axes) def test_create(self): @@ -492,15 +492,15 @@ def test_create(self): }) state = train_state_lib.InferenceState.create(model_variables) self.assertEqual(state.step, 0) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.flax_mutables, flax.core.freeze({'mutables': np.ones(3)}), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.params, model_variables['params'] ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.params_axes, model_variables['params_axes'], @@ -527,7 +527,7 @@ def test_replace_params(self): new_params = {'test': np.zeros(10)} new_state = state.replace_params(new_params) - jax.tree_map(np.testing.assert_array_equal, new_params, new_state.params) + jax.tree.map(np.testing.assert_array_equal, new_params, new_state.params) def test_replace_step(self): model_variables = flax.core.freeze({'params': {'test': np.ones(10)}}) @@ -549,7 +549,7 @@ def test_as_logical_axes(self): state = train_state_lib.InferenceState.create(model_variables) axes_state = state.as_logical_axes() self.assertIsNone(axes_state.params_axes) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, axes_state.params, flax.core.freeze({ @@ -571,7 +571,7 @@ def test_to_state_dict(self): 'mutables': np.ones(3), }) state = train_state_lib.InferenceState.create(model_variables) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.state_dict(), { @@ -593,7 +593,7 @@ def test_to_state_dict_no_mutables(self): }, }) state = train_state_lib.InferenceState.create(model_variables) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, state.state_dict(), { @@ -621,12 +621,12 @@ def test_restore_state(self): restored = state.restore_state(state_dict) self.assertEqual(restored.step, 10) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.flax_mutables, flax.core.freeze(state_dict['flax_mutables']), ) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.params, flax.core.freeze(state_dict['target']), @@ -648,7 +648,7 @@ def test_restore_state_no_mutables_no_axes(self): self.assertEqual(restored.step, 10) self.assertEqual(restored.flax_mutables, train_state_lib.EMPTY_DICT) - jax.tree_map( + jax.tree.map( np.testing.assert_array_equal, restored.params, flax.core.freeze(state_dict['target']), diff --git a/t5x/trainer_test.py b/t5x/trainer_test.py index ecd708f21..ad9ce62d3 100644 --- a/t5x/trainer_test.py +++ b/t5x/trainer_test.py @@ -61,7 +61,7 @@ def _validate_events(test_case, summary_dir, expected_metrics, steps): else: 
actual_events[event.summary.value[0].tag] = float(tf.make_ndarray(tensor)) - jax.tree_map(test_case.assertAlmostEqual, actual_events, expected_metrics) + jax.tree.map(test_case.assertAlmostEqual, actual_events, expected_metrics) class MetricsManagerTest(absltest.TestCase): @@ -204,7 +204,7 @@ def fake_accum_grads( del model, num_microbatches, rng, data_partition_spec # Add `i` to each optimzer value. i = batch['i'].sum() - grad_accum = jax.tree_map(lambda x: i, optimizer) + grad_accum = jax.tree.map(lambda x: i, optimizer) # Add j to each metric. j = batch['j'].sum() metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} @@ -222,7 +222,7 @@ def fake_apply_grads( del weight_metrics_computer del other_state_variables metrics['learning_rate'] = clu.metrics.Average(learning_rate, count=1) - optimizer = jax.tree_map(lambda x, g: x + g, optimizer, grad_accum) + optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum) return optimizer, metrics @@ -269,7 +269,7 @@ def fake_grad_fn_without_weight_sum( target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, ) train_state = train_state_lib.FlaxOptimTrainState(optimizer) - grad_accum = jax.tree_map(lambda x: i, train_state) + grad_accum = jax.tree.map(lambda x: i, train_state) # Add j to each metric. j = batch['j'].sum() metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} @@ -311,7 +311,7 @@ def setUp(self): self.init_train_state = train_state_lib.FlaxOptimTrainState( self.init_optimizer ) - train_state_axes = jax.tree_map(lambda x: None, self.init_train_state) + train_state_axes = jax.tree.map(lambda x: None, self.init_train_state) model_dir = self.create_tempdir().full_path mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} @@ -376,13 +376,13 @@ def _test_train(self, precompile): # 5.0 - 0.0 expected_metrics['timing/uptime'] = 5.0 # 0+1+2+3 = 6 - expected_train_state = jax.tree_map( + expected_train_state = jax.tree.map( lambda x: np.array(x + 6), self.init_train_state ) # Base rng must remain the same np.testing.assert_array_equal(trainer._base_rng, initial_rng) - jax.tree_map( + jax.tree.map( np.testing.assert_equal, trainer.train_state, expected_train_state ) # Expected step is 6 since we increment it along with the other optimizer @@ -866,7 +866,7 @@ def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self): ) i = batch['i'].sum() - expected_grad_accum = jax.tree_map( + expected_grad_accum = jax.tree.map( lambda x: i, self.init_train_state ).params self.assertEqual(expected_grad_accum, grad_accum) @@ -920,7 +920,7 @@ def create_trainer(self, step, random_seed): target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, ) init_train_state = train_state_lib.FlaxOptimTrainState(init_optimizer) - train_state_axes = jax.tree_map(lambda x: None, init_train_state) + train_state_axes = jax.tree.map(lambda x: None, init_train_state) test_trainer = trainer_lib.Trainer( mock.create_autospec(models_lib.BaseModel, instance=True), @@ -944,7 +944,7 @@ def fake_accum_grads_rng( ): del model, batch, num_microbatches, data_partition_spec # Add 1, which will increment the step as a side effect. - grad_accum = jax.tree_map(lambda x: 1, optimizer) + grad_accum = jax.tree.map(lambda x: 1, optimizer) m = {'rng': metrics_lib.Sum(jnp.sum(jax.random.key_data(rng)))} return grad_accum, m, None @@ -979,7 +979,7 @@ def fake_mut_accum_grads( del model, num_microbatches, rng, data_partition_spec # Add `i` to each optimzer value. 
i = batch['i'].sum() - grad_accum = jax.tree_map(lambda x: i, optimizer) + grad_accum = jax.tree.map(lambda x: i, optimizer) # Add j to each metric. j = batch['j'].sum() metrics = { @@ -1001,7 +1001,7 @@ def fake_mut_apply_grads( metrics['learning_rate'] = clu.metrics.Average.from_model_output( learning_rate ) - optimizer = jax.tree_map(lambda x, g: x + g, optimizer, grad_accum) + optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum) return optimizer, metrics @@ -1025,7 +1025,7 @@ def setUp(self): } ), ) - train_state_axes = jax.tree_map(lambda x: None, self.init_train_state) + train_state_axes = jax.tree.map(lambda x: None, self.init_train_state) model_dir = self.create_tempdir().full_path mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} @@ -1066,12 +1066,12 @@ def test_train(self, mock_time=None): batch = next(ds_iter) train_state, _ = trainer._partitioned_train_step(trainer.train_state, batch) - expected_train_state = jax.tree_map( + expected_train_state = jax.tree.map( lambda x: np.array(x + 1), self.init_train_state ) # Base rng must remain the same np.testing.assert_array_equal(trainer._base_rng, initial_rng) - jax.tree_map(np.testing.assert_equal, train_state, expected_train_state) + jax.tree.map(np.testing.assert_equal, train_state, expected_train_state) self.assertIsNone(trainer._compiled_train_step) self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) @@ -1092,7 +1092,7 @@ def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self): ) i = batch['i'].sum() - expected_grad_accum = jax.tree_map( + expected_grad_accum = jax.tree.map( lambda x: i, self.init_train_state ).params self.assertEqual(expected_grad_accum, grad_accum) diff --git a/t5x/utils_test.py b/t5x/utils_test.py index d208ccf67..edcb464ae 100644 --- a/t5x/utils_test.py +++ b/t5x/utils_test.py @@ -276,7 +276,7 @@ def test_get_training_eval_datasets_task(self, mock_get_dataset): ) self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree_map( + jax.tree.map( np.testing.assert_equal, list(ds["mock_task"]), [ @@ -308,7 +308,7 @@ def test_get_training_eval_datasets_task(self, mock_get_dataset): ) self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree_map( + jax.tree.map( np.testing.assert_equal, list(ds["mock_task"]), [ @@ -340,7 +340,7 @@ def test_get_training_eval_datasets_task(self, mock_get_dataset): ) self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree_map( + jax.tree.map( np.testing.assert_equal, list(ds["mock_task"]), [ @@ -436,7 +436,7 @@ def test_get_training_eval_datasets_mixture(self, mock_get_dataset): res.keys(), ["mock_task1", "mock_task2", "mock_mix"] ) for ds in res.values(): - jax.tree_map( + jax.tree.map( np.testing.assert_equal, list(ds), [ @@ -518,7 +518,7 @@ def test_get_training_eval_datasets_mixture_obj(self, mock_get_dataset): res_obj.keys(), ["mock_task3", "mock_task4", "mock_mix2"] ) for ds in res_obj.values(): - jax.tree_map( + jax.tree.map( np.testing.assert_equal, list(ds), [ @@ -579,7 +579,7 @@ def test_override_params_axes_names(self): ], ) - jax.tree_map( + jax.tree.map( np.testing.assert_equal, overridden_variables, flax.core.freeze({
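
For reviewers: a minimal sketch of the rename this patch applies throughout, using made-up tree values (only the APIs named in the commit message are assumed real; the short `jax.tree` aliases require JAX 0.4.25 or newer, while `jax.tree_util` spellings work on older releases).

    import operator
    import jax
    import jax.numpy as jnp

    params = {'kernel': jnp.ones((2, 4)), 'bias': jnp.zeros(4)}

    # Before (deprecated top-level alias, slated for removal):
    #   shapes = jax.tree_map(jnp.shape, params)

    # After: either the short alias or the longer jax.tree_util spelling.
    shapes = jax.tree.map(jnp.shape, params)                 # {'bias': (4,), 'kernel': (2, 4)}
    shapes_alt = jax.tree_util.tree_map(jnp.shape, params)   # equivalent
    assert jax.tree_util.tree_all(jax.tree.map(operator.eq, shapes, shapes_alt))

    # Keyword arguments such as is_leaf carry over unchanged:
    lengths = jax.tree.map(len, {'a': [1, 2, 3]}, is_leaf=lambda x: isinstance(x, list))
    assert lengths == {'a': 3}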