Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.
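As a minimal sketch of the migration pattern applied throughout this change (the toy pytree below is illustrative, not taken from t5x):

```python
import jax
import jax.numpy as jnp

# Illustrative pytree; any nested dict/list/tuple of arrays works the same way.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Deprecated top-level alias, slated for removal:
#   shapes = jax.tree_map(jnp.shape, params)

# Equivalent replacements:
shapes = jax.tree_util.tree_map(jnp.shape, params)  # long-standing jax.tree_util API
shapes = jax.tree.map(jnp.shape, params)            # shorter alias, requires jax >= 0.4.25
```

The rename is one-to-one (`jax.tree_map` becomes `jax.tree.map`); calls that already go through `jax.tree_util`, such as `tree_all` and `tree_flatten`, are left unchanged in the diff below.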

PiperOrigin-RevId: 634412089
Jake VanderPlas authored and t5-copybara committed May 16, 2024
1 parent 1938bf8 commit a42d50b
Showing 17 changed files with 107 additions and 106 deletions.
t5x/adafactor_test.py: 10 changes (5 additions, 5 deletions)
@@ -58,13 +58,13 @@ def check_eq(xs, ys, atol=None, rtol=None):
ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys)
assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
assert jax.tree_util.tree_all(
-jax.tree_map(
+jax.tree.map(
lambda x, y: np.array(x).shape == np.array(y).shape,
xs_leaves,
ys_leaves,
)
), "Leaves' shapes don't match."
-assert jax.tree_map(
+assert jax.tree.map(
functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
xs_leaves,
ys_leaves,
@@ -77,11 +77,11 @@ def flattened_state_dict(x):


def tree_shape(x):
-return jax.tree_map(jnp.shape, x)
+return jax.tree.map(jnp.shape, x)


def tree_equals(x, y):
-return jax.tree_util.tree_all(jax.tree_map(operator.eq, x, y))
+return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y))


def _get_multi_adafactor(
@@ -483,7 +483,7 @@ def test_standard_factor_rules():

# create fake model parameters
k = jax.random.PRNGKey(0)
-params = jax.tree_map(
+params = jax.tree.map(
lambda shape: jax.random.uniform(k, shape),
MODEL_SHAPE,
is_leaf=lambda x: isinstance(x, list),
t5x/checkpoint_importer.py: 2 changes (1 addition, 1 deletion)
@@ -545,7 +545,7 @@ def restore_from_t5_checkpoint(
t5_data = _maybe_correct_relpos_bias(t5_data)
state_dict = _update_state_dict(state_dict, t5_data, strict=strict)
if not lazy_parameters:
-state_dict = jax.tree_map(
+state_dict = jax.tree.map(
lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict
)
return state_dict
t5x/decoding.py: 6 changes (3 additions, 3 deletions)
@@ -762,7 +762,7 @@ def cache_map(fn, cache, apply_to_index: bool = False):
exclusion_list = ['cached_bias', 'position_embedder_index']
keyvals = {k: v for k, v in keyvals.items() if k[-1] not in exclusion_list}

-keyvals = jax.tree_map(fn, keyvals)
+keyvals = jax.tree.map(fn, keyvals)
flat_cache.update(keyvals)
new_cache = traverse_util.unflatten_dict(flat_cache)
if frozen:
@@ -901,7 +901,7 @@ def gather_beams(
def gather_fn(x):
return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype)

-return jax.tree_map(gather_fn, nested)
+return jax.tree.map(gather_fn, nested)
else:
# True gather via fancy indexing.
batch_indices = jnp.reshape(
@@ -912,7 +912,7 @@ def gather_fn(x):
def gather_fn(x):
return x[batch_indices, beam_indices]

-return jax.tree_map(gather_fn, nested)
+return jax.tree.map(gather_fn, nested)


def top_k_two_stage(x, k):
t5x/decoding_test.py: 4 changes (2 additions, 2 deletions)
@@ -891,7 +891,7 @@ def test_cache_map(self):
},
}

-jax.tree_map(
+jax.tree.map(
np.testing.assert_array_equal, decoding.cache_map(fn, cache), gold_cache
)

@@ -956,7 +956,7 @@ def test_cache_map_with_index(self):
},
}

-jax.tree_map(
+jax.tree.map(
np.testing.assert_array_equal,
decoding.cache_map(fn, cache, apply_to_index=True),
gold_cache,
t5x/infer.py: 8 changes (4 additions, 4 deletions)
@@ -273,7 +273,7 @@ def _json_compat(value):
with gfile.GFile(path, 'w') as f:
for i, inp in task_ds.enumerate().as_numpy_iterator():
predictions = all_predictions[i]
-aux_values = jax.tree_map(
+aux_values = jax.tree.map(
f=lambda v, i=i: v[i],
tree=all_aux_values,
is_leaf=lambda v: isinstance(v, (np.ndarray, list)),
@@ -311,7 +311,7 @@ def _json_compat(value):
elif mode == 'score':
json_dict['score'] = _json_compat(predictions)
if aux_values:
-json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
+json_dict['aux'] = jax.tree.map(_json_compat, aux_values)
elif mode == 'predict_with_aux':
assert vocabulary is not None
json_dict['prediction'] = _json_compat(
@@ -322,7 +322,7 @@ def _json_compat(value):
# Truncate padding tokens.
pred = pred[: pred.index(0)] if 0 in pred else pred
json_dict['prediction_tokens'] = pred
-json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
+json_dict['aux'] = jax.tree.map(_json_compat, aux_values)
else:
raise ValueError(f'Invalid mode: {mode}')
json_str = json.dumps(json_dict, cls=json_encoder_cls)
@@ -353,7 +353,7 @@ def _extract_tokens_and_aux_values(inference_fn_outputs) -> Inferences:
permutation = np.argsort(indices)
permute = lambda v: [v[permutation[i]] for i in range(len(permutation))]
tokens = permute(tokens)
-all_aux_values = jax.tree_map(
+all_aux_values = jax.tree.map(
f=permute,
tree=all_aux_values,
is_leaf=lambda v: isinstance(v, (np.ndarray, list)),
t5x/metrics.py: 4 changes (2 additions, 2 deletions)
@@ -298,7 +298,7 @@ def fn(o):
else:
return o

-return jax.tree_map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time))
+return jax.tree.map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time))


def set_step_metrics_num_steps(metrics, num_steps):
@@ -310,4 +310,4 @@ def fn(o):
else:
return o

-return jax.tree_map(fn, metrics, is_leaf=is_metric_obj)
+return jax.tree.map(fn, metrics, is_leaf=is_metric_obj)
t5x/optimizers.py: 20 changes (10 additions, 10 deletions)
@@ -275,7 +275,7 @@ class OptaxStatePartitionRules:
optax.InjectHyperparamsState: (
lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray
count=None,
-hyperparams=jax.tree_map(lambda x: None, state.hyperparams),
+hyperparams=jax.tree.map(lambda x: None, state.hyperparams),
inner_state=OptaxStatePartitionRules.derive_optax_logical_axes(
state.inner_state, params_axes
),
@@ -438,7 +438,7 @@ def derive_logical_axes(self, optimizer, param_logical_axes):
An `optimizers.Optimizer` instance, with all the leafs replaced by t5x
PartitionSpec or None (no partition).
"""
-optimizer_logical_axes = jax.tree_map(
+optimizer_logical_axes = jax.tree.map(
lambda x: None, optimizer.state_dict()
)
optimizer_logical_axes['target'] = param_logical_axes
@@ -698,7 +698,7 @@ def __init__(
self.sub_optimizers = sub_optimizers

def init_state(self, params):
-param_states = jax.tree_map(lambda x: _Marker(), params)
+param_states = jax.tree.map(lambda x: _Marker(), params)
overlap = False
for idx, traversal in enumerate(self.traversals):
for match in traversal.iterate(param_states):
@@ -707,26 +707,26 @@ def init_state(self, params):
if overlap:
raise ValueError(
'Multiple optimizers match the same leaves : '
-+ str(jax.tree_map(lambda match: match._indices, param_states)) # pylint: disable=protected-access
++ str(jax.tree.map(lambda match: match._indices, param_states)) # pylint: disable=protected-access
)

-param_states = jax.tree_map(lambda x: _Marker(), params)
+param_states = jax.tree.map(lambda x: _Marker(), params)
for focus, opt_def in zip(self.traversals, self.sub_optimizers):
ps = _subtree_from_traversal(focus, params)
ss = opt_def.init_state(ps)
param_states = _update_subtree_of_traversal(
focus, param_states, ss.param_states
)
# Update state to None when param is not optimized by any sub optimizer.
-param_states = jax.tree_map(
+param_states = jax.tree.map(
lambda x: (None if isinstance(x, _Marker) else x), param_states
)
return OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)

def apply_gradient(self, hyper_params, params, state, grads):
new_params = params
it = zip(self.traversals, self.sub_optimizers, hyper_params)
-new_param_states = jax.tree_map(lambda x: _Marker(), params)
+new_param_states = jax.tree.map(lambda x: _Marker(), params)
for focus, opt_def, hp in it:
ps = _subtree_from_traversal(focus, params)
gs = _subtree_from_traversal(focus, grads)
Expand All @@ -738,7 +738,7 @@ def apply_gradient(self, hyper_params, params, state, grads):
focus, new_param_states, new_ss.param_states
)
# Update state to None when param is not optimized by any sub optimizer.
-new_param_states = jax.tree_map(
+new_param_states = jax.tree.map(
lambda x: (None if isinstance(x, _Marker) else x), new_param_states
)
return new_params, OptimizerState(state.step + 1, new_param_states)
@@ -772,7 +772,7 @@ def set_param_axes(self, param_logical_axes):

def derive_logical_axes(self, optimizer, param_logical_axes):
"""Derives optimizer logical partitioning from model logical partitions."""
-param_states = jax.tree_map(
+param_states = jax.tree.map(
lambda x: _Marker(), optimizer.state.param_states
)
for focus, opt_def in zip(self.traversals, self.sub_optimizers):
@@ -786,7 +786,7 @@ def derive_logical_axes(self, optimizer, param_logical_axes):
focus, param_states, new_opt.state.param_states
)
# Update axes to None when param is not optimized by any sub optimizer.
-param_states = jax.tree_map(
+param_states = jax.tree.map(
lambda x: (None if isinstance(x, _Marker) else x), param_states
)
return Optimizer(
t5x/optimizers_test.py: 16 changes (8 additions, 8 deletions)
@@ -54,13 +54,13 @@ def check_eq(xs, ys, atol=None, rtol=None):
ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys)
assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
assert jax.tree_util.tree_all(
-jax.tree_map(
+jax.tree.map(
lambda x, y: np.array(x).shape == np.array(y).shape,
xs_leaves,
ys_leaves,
)
), "Leaves' shapes don't match."
-assert jax.tree_map(
+assert jax.tree.map(
functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
xs_leaves,
ys_leaves,
@@ -73,11 +73,11 @@ def flattened_state_dict(x):


def tree_shape(x):
-return jax.tree_map(jnp.shape, x)
+return jax.tree.map(jnp.shape, x)


def tree_equals(x, y):
-return jax.tree_util.tree_all(jax.tree_map(operator.eq, x, y))
+return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y))


def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__):
@@ -160,7 +160,7 @@ def get_params(cls):

@classmethod
def get_params_shapes(cls):
-return jax.tree_map(jnp.shape, cls.get_params())
+return jax.tree.map(jnp.shape, cls.get_params())

@classmethod
def get_param_logical_axes(cls):
@@ -249,7 +249,7 @@ def test_adamw_state_serialization(self):
state_dict = optimizer.state_dict()

chex.assert_trees_all_equal(
-frozen_dict.FrozenDict(jax.tree_map(jnp.shape, state_dict)),
+frozen_dict.FrozenDict(jax.tree.map(jnp.shape, state_dict)),
frozen_dict.FrozenDict({
'target': self.get_params_shapes(),
'state': {
Expand Down Expand Up @@ -292,8 +292,8 @@ def run_train_loop(self, optimizer_def):

learning_rate_fn = utils.create_learning_rate_scheduler()

-input_shapes = jax.tree_map(jnp.shape, first_batch)
-input_types = jax.tree_map(lambda x: jnp.dtype(x.dtype), first_batch)
+input_shapes = jax.tree.map(jnp.shape, first_batch)
+input_types = jax.tree.map(lambda x: jnp.dtype(x.dtype), first_batch)

partitioner = partitioning.PjitPartitioner(
num_partitions=2,
t5x/partitioning.py: 5 changes (3 additions, 2 deletions)
@@ -873,7 +873,8 @@ def get_logical_axes(self, train_state: TrainState) -> TrainState:
"""Returns a copy of TrainState with Optional[AxisNames] as leaves."""
# By default, return None for the logical axes.
return train_state.restore_state(
-jax.tree_map(lambda x: None, train_state.state_dict()))
+jax.tree.map(lambda x: None, train_state.state_dict())
+)

def get_mesh_axes(self, train_state: TrainState) -> TrainState:
"""Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
@@ -1136,7 +1137,7 @@ def _logical_to_mesh_axes(param_name, logical_axes):
# arr_tree is a PyTree of jax.Array or np.ndarray and
# pspecs is PyTree[jax.sharding.PartitionSpec]
def host_local_array_to_global_array(arr_tree, mesh: jax.sharding.Mesh, pspecs):
-pspecs = jax.tree_map(
+pspecs = jax.tree.map(
lambda x: PartitionSpec() if x is None else x,
pspecs,
is_leaf=lambda x: x is None,
t5x/partitioning_test.py: 8 changes (4 additions, 4 deletions)
@@ -252,7 +252,7 @@ def test_get_mesh_axes(self):
adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None),
adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None),
)
-jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec)
+jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec)

axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=True)
expected_axes_spec = self.get_expected_axes_spec(
@@ -263,7 +263,7 @@ def test_get_mesh_axes(self):
m=p1_spec, v=None, v_col=None, v_row=None
),
)
-jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec)
+jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec)

axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=True)
expected_axes_spec = self.get_expected_axes_spec(
@@ -274,7 +274,7 @@ def test_get_mesh_axes(self):
m=p1_spec, v=p1_spec, v_col=None, v_row=None
),
)
-jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec)
+jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec)

axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=False)
expected_axes_spec = self.get_expected_axes_spec(
@@ -285,7 +285,7 @@ def test_get_mesh_axes(self):
m=None, v=p1_spec, v_col=None, v_row=None
),
)
-jax.tree_map(self.assertEqual, axes_spec, expected_axes_spec)
+jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec)

@parameterized.product(activation_dims=(1, 2), param_dims=(1, 2))
def test_standard_logical_axis_rules(self, activation_dims, param_dims):
t5x/precompile.py: 6 changes (3 additions, 3 deletions)
@@ -79,11 +79,11 @@ def precompile(
)

# Need to use full batch size.
-input_shapes = jax.tree_map(
+input_shapes = jax.tree.map(
lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec
)
-input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec)
-dummy_batch = jax.tree_map(
+input_types = jax.tree.map(lambda x: x.dtype, train_iter.element_spec)
+dummy_batch = jax.tree.map(
lambda x: np.ones(x.shape, x.dtype), train_iter.element_spec
)

