Cleaning up existing code with no behaviour change. Inside init2winit checkpoint.py there is a function that does two things: restore the checkpoint, and replicate the params and optimizer state to devices. This change splits those responsibilities: checkpoint.py now only restores, and the trainer shards the restored state onto devices separately.

PiperOrigin-RevId: 719070002
sourabh2k15 authored and copybara-github committed Jan 24, 2025
1 parent 7d8c2a8 commit 0932d97
Showing 3 changed files with 44 additions and 48 deletions.
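
The net effect of the change can be summarized with a small trainer-side sketch. Here maybe_restore_checkpoint and data_utils.shard_pytree are the names used in the diffs below; the wrapper function itself and its argument names are illustrative only, and the sketch assumes shard_pytree returns a (sharding, sharded_value) pair, as the base_trainer.py call sites suggest.

from init2winit import checkpoint
from init2winit.dataset_lib import data_utils


def restore_then_shard(optimizer_state, params, batch_stats, metrics_state,
                       mesh, train_dir, external_checkpoint_path,
                       orbax_checkpointer):
  """Hypothetical wrapper: restore first, then shard (the old call did both)."""
  (optimizer_state, params, batch_stats, metrics_state,
   global_step, sum_train_cost, preemption_count,
   is_restored) = checkpoint.maybe_restore_checkpoint(
       optimizer_state,
       params,
       batch_stats,
       metrics_state,
       train_dir=train_dir,
       external_checkpoint_path=external_checkpoint_path,
       orbax_checkpointer=orbax_checkpointer)

  # Sharding is now the caller's responsibility; shard_pytree is assumed to
  # return a (sharding, sharded_value) pair, matching the base_trainer.py diff.
  _, optimizer_state = data_utils.shard_pytree(optimizer_state, mesh)
  _, params = data_utils.shard_pytree(params, mesh)
  _, batch_stats = data_utils.shard_pytree(batch_stats, mesh)
  _, metrics_state = data_utils.shard_pytree(metrics_state, mesh)

  return (optimizer_state, params, batch_stats, metrics_state,
          global_step, sum_train_cost, preemption_count, is_restored)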
init2winit/checkpoint.py (34 changes: 14 additions & 20 deletions)
@@ -24,7 +24,6 @@
 from absl import flags
 from absl import logging
 from flax.training import checkpoints as flax_checkpoints
-from init2winit.dataset_lib import data_utils
 import jax

 FLAGS = flags.FLAGS
@@ -43,32 +42,27 @@ def load_pytree(pytree_file, orbax_checkpointer=None):
   return None


-def replicate_and_maybe_restore_checkpoint(
+def maybe_restore_checkpoint(
     unreplicated_optimizer_state,
     unreplicated_params,
     unreplicated_batch_stats,
     unreplicated_training_metrics_state,
-    mesh,
     train_dir,
     external_checkpoint_path=None,
     orbax_checkpointer=None):
-  """Replicates everything, and optionally restores from a checkpoint.
+  """Optionally restores from a checkpoint.
   The checkpoint logic is as follows: if there is a checkpoint in `train_dir`,
   restore it. Else, if `external_checkpoint_path` is set, restore the
   checkpoint found there. Else, don't restore any checkpoint, and just
   return the passed-in optimizer_state, params, batch_stats, and
   metrics_grabber.
-  This function is also responsible for replicating the optimizer_state, params,
-  batch_stats, and training_metrics_grabber across multiple devices.
   Args:
     unreplicated_optimizer_state: unreplicated optimizer state
     unreplicated_params: unreplicated params
     unreplicated_batch_stats: unreplicated batch stats
     unreplicated_training_metrics_state: unreplicated metrics state
-    mesh: Mesh specification to use for sharding.
     train_dir: (str) The training directory where we will look for a checkpoint.
     external_checkpoint_path: (str) If this argument is set, then we will load
       the external checkpoint stored there.
@@ -130,30 +124,30 @@ def replicate_and_maybe_restore_checkpoint(
     # Handle failure to load from external_checkpoint_path.
     if ckpt_to_return['global_step'] == -1:
       return (
-          data_utils.shard_pytree(unreplicated_optimizer_state, mesh),
-          data_utils.shard_pytree(unreplicated_params, mesh),
-          data_utils.shard_pytree(unreplicated_batch_stats, mesh),
-          data_utils.shard_pytree(unreplicated_training_metrics_state, mesh),
+          unreplicated_optimizer_state,
+          unreplicated_params,
+          unreplicated_batch_stats,
+          unreplicated_training_metrics_state,
           0, # global_step
           0, # sum_train_cost
           0, # preemption_count
           False) # is_restored
   else: # Else, don't restore from any checkpoint.
     return (
-        data_utils.shard_pytree(unreplicated_optimizer_state, mesh),
-        data_utils.shard_pytree(unreplicated_params, mesh),
-        data_utils.shard_pytree(unreplicated_batch_stats, mesh),
-        data_utils.shard_pytree(unreplicated_training_metrics_state, mesh),
+        unreplicated_optimizer_state,
+        unreplicated_params,
+        unreplicated_batch_stats,
+        unreplicated_training_metrics_state,
        0, # global_step
        0, # sum_train_cost
        0, # preemption_count
        False) # is_restored

   return (
-      data_utils.shard_pytree(ckpt_to_return['optimizer_state'], mesh),
-      data_utils.shard_pytree(ckpt_to_return['params'], mesh),
-      data_utils.shard_pytree(ckpt_to_return['batch_stats'], mesh),
-      data_utils.shard_pytree(ckpt_to_return['training_metrics_grabber'], mesh),
+      ckpt_to_return['optimizer_state'],
+      ckpt_to_return['params'],
+      ckpt_to_return['batch_stats'],
+      ckpt_to_return['training_metrics_grabber'],
       ckpt_to_return['global_step'], # global_step
       ckpt_to_return['sum_train_cost'],
       ckpt_to_return['preemption_count'], # preemption_count
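
The docstring above fixes the restore precedence, but most of the function body is collapsed in this view. The standalone sketch below only illustrates that fallback order (train_dir first, then external_checkpoint_path, otherwise keep the freshly initialized state); it is not the checkpoint.py implementation, and latest_checkpoint_exists is a hypothetical stand-in for the real lookup.

import os


def pick_restore_source(train_dir, external_checkpoint_path=None):
  """Illustrates the fallback order documented in maybe_restore_checkpoint."""

  def latest_checkpoint_exists(directory):
    # Hypothetical helper: treat any non-empty directory as holding a checkpoint.
    return bool(directory) and os.path.isdir(directory) and bool(os.listdir(directory))

  if latest_checkpoint_exists(train_dir):
    return 'train_dir'  # resume a preempted or interrupted run
  if external_checkpoint_path:
    return 'external'  # warm-start from another experiment's checkpoint
  return 'fresh'  # keep the passed-in, freshly initialized state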
init2winit/test_checkpoint.py (36 changes: 14 additions & 22 deletions)
@@ -27,7 +27,6 @@
 from init2winit import checkpoint
 from init2winit.model_lib import models
 from init2winit.shared_test_utilities import pytree_equal
-from jax.experimental import mesh_utils
 import jax.numpy as jnp
 import jax.tree_util
 import numpy as np
@@ -64,12 +63,6 @@ def setUp(self):
         orbax_checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
     self.params = init_dict['params']

-    mesh_shape = (jax.device_count(),)
-    self.mesh = jax.sharding.Mesh(
-        mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()),
-        axis_names=('devices',),
-    )
-
   def tearDown(self):
     shutil.rmtree(self.test_dir)
     super(CheckpointTest, self).tearDown()
@@ -154,36 +147,35 @@ def test_all_variables_restored(self):
         max_to_keep=1)

     (
-        (_, ret_state),
-        (_, ret_params),
-        (_, ret_batch_stats),
-        (_, ret_training_metrics),
+        ret_state,
+        ret_params,
+        ret_batch_stats,
+        ret_training_metrics,
         ret_global_step,
         ret_sum_train_cost,
         ret_preemption_count,
         ret_is_restored,
-    ) = checkpoint.replicate_and_maybe_restore_checkpoint(
+    ) = checkpoint.maybe_restore_checkpoint(
         initial_optimizer_state,
         initial_params,
         initial_batch_stats,
         initial_training_metrics,
-        self.mesh,
         fresh_train_dir,
         orbax_checkpointer=self.orbax_checkpointer,
     )

     assert pytree_equal(
-        jax.device_get(ret_state), saved_optimizer_state
+        ret_state, saved_optimizer_state
     )
     assert pytree_equal(
-        jax.device_get(ret_params), saved_params
+        ret_params, saved_params
     )
     assert pytree_equal(
-        jax.device_get(ret_batch_stats),
+        ret_batch_stats,
         saved_batch_stats,
     )
     assert pytree_equal(
-        jax.device_get(ret_training_metrics),
+        ret_training_metrics,
         saved_training_metrics,
     )
     self.assertEqual(ret_sum_train_cost, sum_train_cost)
@@ -193,7 +185,7 @@ def test_all_variables_restored(self):

     shutil.rmtree(fresh_train_dir)

-  def test_replicate_and_maybe_restore_from_checkpoint_logic(self):
+  def test_maybe_restore_from_checkpoint_logic(self):
     """Test that the right checkpoint is returned.
     1. If no external_checkpoint_path was passed, and if there is no
@@ -238,13 +230,13 @@ def save_checkpoint(train_dir, global_step, preemption_count,
     def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
       """Helper function to replicate_and_maybe_restore a checkpoint."""

-      (_, (_, ret_params), _, _,
+      (_, ret_params, _, _,
        ret_global_step, ret_sum_train_cost, ret_preemption_count,
-       ret_is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
-           {}, params, {}, {}, self.mesh, train_dir, external_checkpoint_path,
+       ret_is_restored) = checkpoint.maybe_restore_checkpoint(
+           {}, params, {}, {}, train_dir, external_checkpoint_path,
            orbax_checkpointer=self.orbax_checkpointer)

-      ret_params_unrep = jax.device_get(ret_params)
+      ret_params_unrep = ret_params

       return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
               ret_preemption_count, ret_is_restored)
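
With the sharding step gone, the assertions above compare the returned pytrees directly (no jax.device_get) using pytree_equal from shared_test_utilities. That helper is not part of this diff; a minimal sketch of such a structure-plus-value comparison, purely for illustration, could look like:

import jax
import numpy as np


def pytree_equal_sketch(tree_a, tree_b):
  """True iff both pytrees share a structure and have elementwise-equal leaves."""
  if jax.tree_util.tree_structure(tree_a) != jax.tree_util.tree_structure(tree_b):
    return False
  return all(
      np.array_equal(np.asarray(a), np.asarray(b))
      for a, b in zip(jax.tree_util.tree_leaves(tree_a),
                      jax.tree_util.tree_leaves(tree_b)))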
init2winit/trainer_lib/base_trainer.py (22 changes: 16 additions & 6 deletions)
@@ -25,6 +25,7 @@
 from init2winit import checkpoint
 from init2winit import schedules
 from init2winit import utils
+from init2winit.dataset_lib import data_utils
 from init2winit.optimizer_lib import gradient_accumulator
 from init2winit.optimizer_lib import optimizers
 from init2winit.trainer_lib import trainer_utils
@@ -300,24 +301,33 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn):
         unreplicated_batch_stats)

     (
-        (optimizer_state_sharding, optimizer_state),
-        (params_sharding, params),
-        (batch_stats_sharding, batch_stats),
-        (metrics_state_sharding, metrics_state),
+        unreplicated_optimizer_state,
+        unreplicated_params,
+        unreplicated_batch_stats,
+        unreplicated_metrics_state,
         global_step,
         sum_train_cost,
         preemption_count,
         is_restored,
-    ) = checkpoint.replicate_and_maybe_restore_checkpoint(
+    ) = checkpoint.maybe_restore_checkpoint(
         unreplicated_optimizer_state,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
-        self._mesh,
         train_dir=self._train_dir,
         external_checkpoint_path=self._external_checkpoint_path,
         orbax_checkpointer=self._orbax_checkpointer,
     )

+    optimizer_state_sharding, optimizer_state = data_utils.shard_pytree(
+        unreplicated_optimizer_state, self._mesh)
+    params_sharding, params = data_utils.shard_pytree(
+        unreplicated_params, self._mesh)
+    batch_stats_sharding, batch_stats = data_utils.shard_pytree(
+        unreplicated_batch_stats, self._mesh)
+    metrics_state_sharding, metrics_state = data_utils.shard_pytree(
+        unreplicated_metrics_state, self._mesh)
+
     if is_restored:
       preemption_count += 1
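
data_utils.shard_pytree itself is not shown in this commit. The sketch below is an assumption about its shape based only on the call sites above: it takes an unreplicated pytree plus a jax.sharding.Mesh and returns a (sharding, sharded_pytree) pair. For simplicity it replicates every leaf (an empty PartitionSpec); the real helper may shard leaves differently.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_pytree_sketch(pytree, mesh: Mesh):
  """Places every leaf on the mesh, fully replicated, and returns (sharding, pytree)."""
  replicated = NamedSharding(mesh, PartitionSpec())  # empty spec means replicate
  sharding = jax.tree_util.tree_map(lambda _: replicated, pytree)
  sharded = jax.tree_util.tree_map(
      lambda leaf: jax.device_put(leaf, replicated), pytree)
  return sharding, sharded

Used the way base_trainer.py uses the real helper, this would read, for example: params_sharding, params = shard_pytree_sketch(unreplicated_params, self._mesh).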
