Add relative_steps to train interface.
PiperOrigin-RevId: 613573339
T5X Team authored and t5-copybara committed Mar 7, 2024
1 parent 20e3e03 commit ecc6369
Showing 5 changed files with 38 additions and 21 deletions.
5 changes: 5 additions & 0 deletions t5x/configs/runs/finetune.gin
@@ -16,6 +16,9 @@
# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
#
# Otherwise, use TRAIN_STEPS_RELATIVE to specify the number of additional
# training steps to perform on top of the initial checkpoint.
#
# Commonly overridden options:
# - DROPOUT_RATE
# - BATCH_SIZE
@@ -61,6 +64,7 @@ JSON_WRITE_N_RESULTS = None # Write all inferences.
USE_HARDWARE_RNG = False
# None always uses the faster hardware RNG
RANDOM_SEED = None
TRAIN_STEPS_RELATIVE = None

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
@@ -77,6 +81,7 @@ train_script.train:
total_steps = %TRAIN_STEPS
eval_steps = %EVAL_STEPS
eval_period = %EVAL_PERIOD
relative_steps = %TRAIN_STEPS_RELATIVE
random_seed = %RANDOM_SEED
use_hardware_rng = %USE_HARDWARE_RNG
summarize_config_fn = @gin_utils.summarize_gin_config
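For reference, the semantics of the new setting, restated as a minimal self-contained Python sketch. The resolve_total_steps helper and its standalone framing are illustrative, not T5X API; the logic mirrors the t5x/train.py change below.

# Minimal sketch of the step-resolution semantics this commit introduces.
# `resolve_total_steps` is a hypothetical helper, not a T5X function.
from typing import Optional


def resolve_total_steps(
    total_steps: int,
    checkpoint_step: int,
    relative_steps: Optional[int] = None,
) -> int:
  """Returns the absolute step that training should run to.

  If `relative_steps` is set, it takes precedence: training runs that many
  steps past the step restored from the checkpoint, and `total_steps`
  (TRAIN_STEPS) is ignored.
  """
  if relative_steps:
    return checkpoint_step + relative_steps
  return total_steps


# A 1M-step pre-trained checkpoint fine-tuned for 0.1M additional steps:
assert resolve_total_steps(0, 1_000_000, relative_steps=100_000) == 1_100_000
# Without TRAIN_STEPS_RELATIVE, the absolute TRAIN_STEPS value is used:
assert resolve_total_steps(1_100_000, 1_000_000) == 1_100_000

With this, a fine-tuning run no longer needs to know the checkpoint's absolute step count up front.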
2 changes: 2 additions & 0 deletions t5x/configs/runs/pretrain.gin
@@ -42,6 +42,7 @@ SHUFFLE_TRAIN_EXAMPLES = True
USE_HARDWARE_RNG = False
# None always uses faster, hardware RNG
RANDOM_SEED = None
TRAIN_STEPS_RELATIVE = None

# Can be overridden with `train.*`.
train_script.train:
@@ -56,6 +57,7 @@ train_script.train:
total_steps = %TRAIN_STEPS
eval_steps = 20
eval_period = 1000
relative_steps = %TRAIN_STEPS_RELATIVE
random_seed = %RANDOM_SEED
use_hardware_rng = %USE_HARDWARE_RNG
summarize_config_fn = @gin_utils.summarize_gin_config
2 changes: 2 additions & 0 deletions t5x/fiddle_configs/configs/finetune.py
@@ -60,6 +60,7 @@ def train(
task_feature_lengths: Mapping[str, int],
eval_steps: int = EVAL_STEPS,
eval_period: int = EVAL_PERIOD,
relative_steps: Optional[int] = None,
random_seed: Optional[int] = RANDOM_SEED,
mixture_or_task_module: Optional[str] = None,
use_hardware_rng: bool = USE_HARDWARE_RNG,
@@ -119,6 +120,7 @@ def train(
total_steps=train_steps,
eval_steps=eval_steps,
eval_period=eval_period,
relative_steps=relative_steps,
random_seed=random_seed,
use_hardware_rng=use_hardware_rng,
summarize_config_fn=config_utils.summarize_fiddle_config,
2 changes: 2 additions & 0 deletions t5x/fiddle_configs/configs/pretrain.py
@@ -54,6 +54,7 @@ def train(
task_feature_lengths: Mapping[str, int],
eval_steps: int = EVAL_STEPS,
eval_period: int = EVAL_PERIOD,
relative_steps: Optional[int] = None,
random_seed: Optional[int] = RANDOM_SEED,
mixture_or_task_module: Optional[str] = None,
use_hardware_rng: bool = USE_HARDWARE_RNG,
@@ -103,6 +104,7 @@ def train(
total_steps=train_steps,
eval_steps=eval_steps,
eval_period=eval_period,
relative_steps=relative_steps,
random_seed=random_seed,
use_hardware_rng=use_hardware_rng,
summarize_config_fn=config_utils.summarize_fiddle_config,
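Both fiddle wrappers simply thread the new keyword through to train_script.train with a None default. A hedged sketch of that forwarding pattern follows; train_entry and train_loop are hypothetical stand-ins, not T5X functions.

# Illustrative only: the forwarding pattern used by the fiddle config
# wrappers above. Both functions here are hypothetical stand-ins.
from typing import Optional


def train_loop(total_steps: int, relative_steps: Optional[int] = None) -> None:
  # In T5X, resolution against the restored checkpoint step happens here.
  if relative_steps:
    print(f'will train for {relative_steps} steps past the checkpoint step')
  else:
    print(f'will train to absolute step {total_steps}')


def train_entry(
    train_steps: int,
    relative_steps: Optional[int] = None,  # new keyword, defaults to None
) -> None:
  # The wrapper forwards relative_steps unchanged, like the diffs above.
  train_loop(total_steps=train_steps, relative_steps=relative_steps)


train_entry(1_100_000)                          # absolute step budget
train_entry(1_100_000, relative_steps=100_000)  # relative step budget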
48 changes: 27 additions & 21 deletions t5x/train.py
@@ -108,6 +108,7 @@ def train(
total_steps: int,
eval_steps: int,
eval_period: int,
relative_steps: Optional[int] = None,
stats_period: Optional[int] = None,
random_seed: Optional[int],
use_hardware_rng: bool = False,
@@ -154,6 +155,8 @@ def train(
eval_steps: The number of batches to process for each train-eval loop.
eval_period: The number of train steps between each evaluation (both
train-eval and infer-eval).
relative_steps: The number of train steps to take relative to the current
step loaded from the checkpoint. If this is set, total_steps is ignored.
stats_period: The number of train steps between writing scalar stats. If
None, defaults to eval_period.
random_seed: A random seed to use for dropout and initialization. If None, a
@@ -214,27 +217,6 @@
checkpoint_cfg.save.checkpoint_steps if checkpoint_cfg.save else []
)

if eval_period or checkpoint_period or gc_period:
steps_per_epoch = min(
eval_period or np.inf, checkpoint_period or np.inf, gc_period or np.inf
)
else:
steps_per_epoch = total_steps
stats_period = stats_period or steps_per_epoch
if (
eval_period
and eval_period % steps_per_epoch
or checkpoint_period
and checkpoint_period % steps_per_epoch
or gc_period
and gc_period % steps_per_epoch
):
raise ValueError(
f'Checkpoint period ({checkpoint_period}), eval '
f'period ({eval_period}), and GC period ({gc_period}) must all be '
'multiples of each other.'
)

if use_hardware_rng or random_seed is None:
logging.info(
'Using fast RngBitGenerator PRNG for initialization and dropout.'
@@ -420,6 +402,30 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
# Restore step from last checkpoint or set to 0 if training from scratch.
host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error

if relative_steps:
total_steps = host_step + relative_steps

if eval_period or checkpoint_period or gc_period:
steps_per_epoch = min(
eval_period or np.inf, checkpoint_period or np.inf, gc_period or np.inf
)
else:
steps_per_epoch = total_steps
stats_period = stats_period or steps_per_epoch
if (
eval_period
and eval_period % steps_per_epoch
or checkpoint_period
and checkpoint_period % steps_per_epoch
or gc_period
and gc_period % steps_per_epoch
):
raise ValueError(
f'Checkpoint period ({checkpoint_period}), eval '
f'period ({eval_period}), and GC period ({gc_period}) must all be '
'multiples of each other.'
)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
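The relocated block, restated below as a self-contained sketch (the function wrapper is illustrative; in t5x/train.py the logic is inline, as shown above). Moving it below the checkpoint restore is what allows total_steps to be recomputed from host_step before the period check runs; in particular, the steps_per_epoch = total_steps fallback is only meaningful once the final total_steps is known.

# Sketch of the epoch-period check that this commit moves below checkpoint
# restore. The function framing is illustrative; the real code is inline.
from typing import Optional

import numpy as np


def compute_steps_per_epoch(
    total_steps: int,
    eval_period: Optional[int],
    checkpoint_period: Optional[int],
    gc_period: Optional[int],
) -> int:
  if eval_period or checkpoint_period or gc_period:
    steps_per_epoch = min(
        eval_period or np.inf, checkpoint_period or np.inf, gc_period or np.inf
    )
  else:
    steps_per_epoch = total_steps
  # Every configured period must be a whole multiple of the smallest one;
  # otherwise eval, checkpoint, and GC boundaries would fall out of sync.
  for period in (eval_period, checkpoint_period, gc_period):
    if period and period % steps_per_epoch:
      raise ValueError(
          f'Checkpoint period ({checkpoint_period}), eval period '
          f'({eval_period}), and GC period ({gc_period}) must all be '
          'multiples of each other.'
      )
  return int(steps_per_epoch)


# E.g., eval every 1000 steps and checkpoint every 5000 steps is valid:
assert compute_steps_per_epoch(10_000, 1000, 5000, None) == 1000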
