13 changes: 12 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -190,8 +190,19 @@ def __init__(
         self.load_only = checkpoint_config.load_only

         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )

+        if ft_manager and ft_manager.enabled and not self.ft_manager:
+            logger.warn(
+                "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
+                "This means replicas can retrain over the same data multiple times, which can result in overfitting."
+            )
+
         if self.ft_manager:
             optimizers.init_cache_state_dict()

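As a minimal sketch of the gating behavior this hunk introduces (FTManager, CheckpointConfig, and resolve_ft_manager below are simplified, hypothetical stand-ins, not the actual torchtitan classes): when fault tolerance is enabled but enable_ft_dataloader_checkpoints is off, no torchft manager is kept for dataloader checkpointing and a warning is emitted.

from dataclasses import dataclass
from typing import Optional
import logging

logger = logging.getLogger(__name__)


@dataclass
class FTManager:
    # Hypothetical stand-in for torchtitan's torchft manager wrapper.
    enabled: bool = False
    manager: Optional[object] = None


@dataclass
class CheckpointConfig:
    # Mirrors the new config flag added in job_config.py.
    enable_ft_dataloader_checkpoints: bool = True


def resolve_ft_manager(
    ft_manager: Optional[FTManager], checkpoint_config: CheckpointConfig
) -> Optional[object]:
    # Keep the torchft manager only when fault tolerance is enabled AND the
    # dataloader-checkpointing flag is on; otherwise fall back to None.
    resolved = (
        ft_manager.manager
        if ft_manager
        and ft_manager.enabled
        and checkpoint_config.enable_ft_dataloader_checkpoints
        else None
    )
    # Same warning path as the hunk: FT is on but dataloader checkpoints are off.
    if ft_manager and ft_manager.enabled and resolved is None:
        logger.warning(
            "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is "
            "False; replicas can retrain over the same data."
        )
    return resolved


# Example: FT enabled, flag disabled -> manager dropped and a warning logged.
assert resolve_ft_manager(
    FTManager(enabled=True, manager=object()),
    CheckpointConfig(enable_ft_dataloader_checkpoints=False),
) is None
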
22 changes: 22 additions & 0 deletions torchtitan/config/job_config.py
@@ -430,6 +430,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""

+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: Disabling this can cause fault tolerant replicas to train over the
+    same data multiple times. Disable it only if training over the same data is
+    acceptable.
+
+    Enables checkpointing of the dataloader index for fault tolerant training
+    with torchft.
+
+    Fault tolerant training stores the data loader index in the checkpoints, so
+    that training can resume without going over the same batch twice.
+
+    If enabled, the data loader state is checkpointed. Otherwise, replicas can
+    train over the same data multiple times, which can result in overfitting.
+
+    A failed replica will still recover other state, e.g. model parameters, from
+    the other replicas.
+
+    Note that if regular checkpointing is enabled, the data loader state is also
+    checkpointed. But when not using fault tolerance, the entire training starts
+    from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.
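And a hedged usage sketch for the new flag, assuming the Checkpoint dataclass shown above is constructed directly; in practice torchtitan populates it from the job's TOML/CLI config, whose exact option spelling is not shown here.

from torchtitan.config.job_config import Checkpoint

# Assumption: constructing the dataclass directly for illustration; real runs
# set this through the job config instead.
checkpoint_cfg = Checkpoint(
    enable=True,
    # Accept the risk of replicas retraining over the same data after a failure.
    enable_ft_dataloader_checkpoints=False,
)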