diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py
index b5e09309b..bc5971f6f 100644
--- a/t5x/checkpoints.py
+++ b/t5x/checkpoints.py
@@ -69,7 +69,9 @@
 LazyArray = checkpoint_importer.LazyArray
 LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray
 LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray
-Dataset = Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
+Dataset = Union[
+    tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator, None
+]
 
 # Version 3 is used since 2021-06-10, compared to version 2 the only change is
 # that `bfloat16` arrays are written in Tensorstore using its native `bfloat16`
@@ -2019,8 +2021,9 @@ class _OrbaxParamInfo:
 class DatasetCheckpointHandler(ocp.CheckpointHandler):
   """A CheckpointHandler implementation that handles tf.data.Iterator."""
 
-  def __init__(self, checkpoint_filename: str):
+  def __init__(self, checkpoint_filename: str, should_write_dataset_ckpt: bool):
     self._checkpoint_filename = checkpoint_filename
+    self._should_write_dataset_ckpt = should_write_dataset_ckpt
 
   def save(
       self,
@@ -2033,18 +2036,18 @@ def save(
       directory: save location directory.
       args: DatasetArgs (see below).
     """
-    item = args.item
-    if item is None:
-      raise ValueError('Must provide item to save.')
-    if jax.process_count() > 1:
-      directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
-      directory.mkdir(parents=False, exist_ok=False)
-    if isinstance(item, tf.data.Iterator):
-      ckpt = tf.train.Checkpoint(ds=item)
-      ckpt.write(os.fspath(directory / self._checkpoint_filename))
-    elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
-      item.save(os.fspath(directory / self._checkpoint_filename))
-    multihost_utils.sync_global_devices('DatasetCheckpointHandler:save')
+    if self._should_write_dataset_ckpt:
+      item = args.item
+      if item is None:
+        raise ValueError('Must provide item to save.')
+      if jax.process_count() > 1:
+        directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
+        directory.mkdir(parents=False, exist_ok=False)
+      if isinstance(item, tf.data.Iterator):
+        ckpt = tf.train.Checkpoint(ds=item)
+        ckpt.write(os.fspath(directory / self._checkpoint_filename))
+      elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
+        item.save(os.fspath(directory / self._checkpoint_filename))
 
   def restore(
       self,
@@ -2060,19 +2063,20 @@ def restore(
     Returns:
       a tf.data.Iterator restored from `directory`.
     """
-    if args is None:
-      raise ValueError('Must provide args to restore.')
-    item = args.item
-    if jax.process_count() > 1:
-      directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
-    if isinstance(item, tf.data.Iterator):
-      ckpt = tf.train.Checkpoint(ds=item)
-      ckpt.read(
-          os.fspath(directory / self._checkpoint_filename)
-      ).assert_consumed()
-    elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
-      item.load(os.fspath(directory / self._checkpoint_filename))
-    return item
+    if self._should_write_dataset_ckpt:
+      if args is None:
+        raise ValueError('Must provide args to restore.')
+      item = args.item
+      if jax.process_count() > 1:
+        directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
+      if isinstance(item, tf.data.Iterator):
+        ckpt = tf.train.Checkpoint(ds=item)
+        ckpt.read(
+            os.fspath(directory / self._checkpoint_filename)
+        ).assert_consumed()
+      elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
+        item.load(os.fspath(directory / self._checkpoint_filename))
+      return item
 
 
 @ocp.args.register_with_handler(
@@ -2323,11 +2327,13 @@ def __init__(
     )
     # TODO(b/273803615) Enable OCDBT.
     self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=False)
-    item_handlers = {_STATE_KEY: self._state_handler}
-    if self._should_write_dataset_ckpt:
-      item_handlers[_DATASET_KEY] = DatasetCheckpointHandler(
-          checkpoint_filename=dataset_ckpt_name
-      )
+    item_handlers = {
+        _STATE_KEY: self._state_handler,
+        _DATASET_KEY: DatasetCheckpointHandler(
+            checkpoint_filename=dataset_ckpt_name,
+            should_write_dataset_ckpt=self._should_write_dataset_ckpt,
+        ),
+    }
 
     def best_fn(metrics):
       return metrics[metric_name_to_monitor]
@@ -2425,9 +2431,8 @@ def save(
             state_dict,
             save_args=save_args,
         ),
+        _DATASET_KEY: DatasetArgs(self._dataset_iterator),
     }
-    if self._should_write_dataset_ckpt:
-      args[_DATASET_KEY] = DatasetArgs(self._dataset_iterator)
     args = ocp.args.Composite(**args)
     saved = self._manager.save(step, args=args, force=force)
diff --git a/t5x/test_utils.py b/t5x/test_utils.py
index 960f240bd..a969a6faf 100644
--- a/t5x/test_utils.py
+++ b/t5x/test_utils.py
@@ -30,6 +30,7 @@
 import seqio
 import t5.data
 from t5x import adafactor
+from t5x import checkpoints
 from t5x import models
 from t5x import optimizers
 from t5x import partitioning
@@ -409,3 +410,51 @@ def partition(
 
   def compile(self, partitioned_fn, *args):
     return None
+
+# -------------------- Checkpoint helpers --------------------
+
+
+def _train_state_shapes(train_state):
+  def _maybe_get(x):
+    if isinstance(x, LazyArray):
+      return x.get()
+    return x
+
+  train_state = jax.tree_util.tree_map(_maybe_get, train_state)
+  return jax.eval_shape(lambda: train_state)
+
+
+def save(checkpointer_or_manager, train_state, force=False):
+  saved = checkpointer_or_manager.save(train_state, force=force)
+  checkpointer_or_manager.wait_until_finished()
+  return saved
+
+
+def create_checkpointer_or_manager(
+    train_state_shapes,
+    partitioner,
+    directory,
+    dataset_iterator=None,
+    save_dtype=None,
+    restore_dtype=None,
+    best=False,
+    keep=None,
+    period=1,
+    checkpoint_steps=None,
+    keep_checkpoints_without_metrics=True,
+):
+  """Creates an Orbax CheckpointManagerInterface."""
+  metric_name_to_monitor = 'train/accuracy' if best else None
+  return checkpoints.OrbaxCheckpointManagerInterface(
+      directory,
+      train_state_shapes,
+      partitioner,
+      dataset_iterator=dataset_iterator,
+      save_dtype=save_dtype,
+      restore_dtype=restore_dtype,
+      keep=keep,
+      period=period,
+      checkpoint_steps=checkpoint_steps,
+      metric_name_to_monitor=metric_name_to_monitor,
+      keep_checkpoints_without_metrics=keep_checkpoints_without_metrics,
+  )