Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634070357
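
A minimal sketch of the rename this commit applies (not taken from the diff itself; it assumes JAX 0.4.25 or newer so the `jax.tree` submodule is available, and the example pytree is invented for illustration):

import jax

params = {'dense': {'kernel': [1.0, 2.0, 3.0], 'bias': [0.5]}}

# Deprecated top-level aliases, slated for removal:
#   jax.tree_map(fn, tree), jax.tree_leaves(tree)

# Long-form replacements live in jax.tree_util:
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

# Shorter aliases in the jax.tree submodule (JAX >= 0.4.25):
doubled = jax.tree.map(lambda x: x * 2, params)
leaves = jax.tree.leaves(doubled)  # flattened leaf values of the doubled tree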
Jake VanderPlas authored and t5-copybara committed May 15, 2024
1 parent 717cb3c commit 1938bf8
Showing 7 changed files with 22 additions and 22 deletions.
14 changes: 7 additions & 7 deletions t5x/contrib/moe/checkpoints_test.py
@@ -78,7 +78,7 @@ def shard_train_state(
mesh_axes: Optional[PartitionSpec],
) -> FlaxOptimTrainState:
"""Helper to construct a sharded train state from NumPy arrays."""
-return jax.tree_map(
+return jax.tree.map(
functools.partial(
create_sharded_array, global_mesh=global_mesh, mesh_axes=mesh_axes
),
@@ -374,7 +374,7 @@ def validate_restore(
restore_dtype=expected_restore_dtype,
)
if lazy_parameters:
-actual_train_state = jax.tree_map(lambda x: x.get(), actual_train_state)
+actual_train_state = jax.tree.map(lambda x: x.get(), actual_train_state)

# Validate.

@@ -418,12 +418,12 @@ def validate_restore(
},
}

-jax.tree_map(
+jax.tree.map(
np.testing.assert_array_equal,
actual_train_state.params,
expected_per_host_params,
)
-jax.tree_map(
+jax.tree.map(
np.testing.assert_array_equal,
actual_train_state.param_states,
expected_per_host_param_states,
@@ -442,8 +442,8 @@ def validate_restore(

self.assertTrue(
all(
-jax.tree_leaves(
-jax.tree_map(
+jax.tree.leaves(
+jax.tree.map(
lambda x: x.dtype == expected_restore_dtype,
actual_train_state.params,
)
@@ -519,7 +519,7 @@ def save(
i,
host_count,
num_partitions,
-mesh_axes=jax.tree_map(lambda x: None, self.dense_model_mesh_axes)
+mesh_axes=jax.tree.map(lambda x: None, self.dense_model_mesh_axes)
if disable_partitioning
else self.dense_model_mesh_axes,
)
8 changes: 4 additions & 4 deletions t5x/contrib/moe/trainer_test.py
@@ -35,7 +35,7 @@ def fake_accum_grads(
del model, num_microbatches, rng, data_partition_spec
# Add `i` to each optimizer value.
i = batch['i'].sum()
-grad_accum = jax.tree_map(lambda x: i, optimizer)
+grad_accum = jax.tree.map(lambda x: i, optimizer)
# Add j to each metric.
j = batch['j'].sum()
metrics = {
@@ -56,7 +56,7 @@ def fake_apply_grads(
del weight_metrics_computer
del other_state_variables
metrics['learning_rate'] = metrics_lib.Sum.from_model_output(learning_rate)
-optimizer = jax.tree_map(lambda x, g: x + g, optimizer, grad_accum)
+optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum)
return optimizer, metrics


@@ -74,7 +74,7 @@ def setUp(self):
self.init_train_state = train_state_lib.FlaxOptimTrainState(
self.init_optimizer
)
-train_state_axes = jax.tree_map(lambda x: None, self.init_train_state)
+train_state_axes = jax.tree.map(lambda x: None, self.init_train_state)
model_dir = self.create_tempdir().full_path

mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
@@ -138,7 +138,7 @@ def _test_train(self, precompile, mock_time=None):
expected_train_state = train_state_lib.FlaxOptimTrainState(
expected_optimizer
)
-jax.tree_map(
+jax.tree.map(
np.testing.assert_allclose, trainer.train_state, expected_train_state
)

2 changes: 1 addition & 1 deletion t5x/contrib/moe/training_utils.py
@@ -74,7 +74,7 @@ def scale_sharded_grads(


def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
-"""Like jax.tree_map but with a filter on the leaf path name.
+"""Like jax.tree.map but with a filter on the leaf path name.
Args:
f: The function to be applied to each parameter in `param_tree`.
2 changes: 1 addition & 1 deletion t5x/contrib/moe/training_utils_test.py
@@ -62,7 +62,7 @@ def test_scale_sharded_grads(self):
'regular_layer': jnp.ones((1, 2)),
}
})
-jax.tree_map(
+jax.tree.map(
functools.partial(np.testing.assert_allclose, rtol=3e-7),
scaled_grads,
expected_grads,
6 changes: 3 additions & 3 deletions t5x/examples/decoder_only/layers_test.py
@@ -606,7 +606,7 @@ def test_mlp_same_out_dim(self):
)
params = module.init(random.PRNGKey(0), inputs, deterministic=True)
self.assertEqual(
-jax.tree_map(lambda a: a.tolist(), params),
+jax.tree.map(lambda a: a.tolist(), params),
{
'params': {
'wi': {
@@ -682,7 +682,7 @@ def test_relative_attention_bidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
@@ -718,7 +718,7 @@ def test_relative_attention_unidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
6 changes: 3 additions & 3 deletions t5x/examples/scalable_t5/layers_test.py
@@ -553,7 +553,7 @@ def test_mlp_same_out_dim(self):
)
params = module.init(random.PRNGKey(0), inputs, deterministic=True)
self.assertEqual(
-jax.tree_map(lambda a: a.tolist(), params),
+jax.tree.map(lambda a: a.tolist(), params),
{
'params': {
'wi': {
@@ -629,7 +629,7 @@ def test_relative_attention_bidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
@@ -665,7 +665,7 @@ def test_relative_attention_unidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
6 changes: 3 additions & 3 deletions t5x/examples/t5/layers_test.py
@@ -553,7 +553,7 @@ def test_mlp_same_out_dim(self):
)
params = module.init(random.PRNGKey(0), inputs, deterministic=True)
self.assertEqual(
-jax.tree_map(lambda a: a.tolist(), params),
+jax.tree.map(lambda a: a.tolist(), params),
{
'params': {
'wi': {
@@ -629,7 +629,7 @@ def test_relative_attention_bidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
@@ -665,7 +665,7 @@ def test_relative_attention_unidirectional_params(self):
params = self.relative_attention.init(
random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False
)
-param_shapes = jax.tree_map(lambda x: x.shape, params)
+param_shapes = jax.tree.map(lambda x: x.shape, params)
self.assertEqual(
param_shapes,
{
