Adding logic for cleaning up FT checkpoints #1528
This PR seems to be submitted to a very old branch. We've removed `JobConfig` from the `CheckpointManager` signature to make it a useful util in general. Please rebase and respect that.
torchtitan/config_manager.py
Outdated
@@ -680,6 +680,11 @@ class FaultTolerance:
    This is only used when "semi_sync_method" is set.
    """

    checkpoint_keep_latest_k: int = 0
Why can't you use `checkpoint.keep_latest_k`?
When torchft is enabled there are two types of checkpoints:
1. Full checkpoints (these can also be enabled without torchft)
2. Per-replica checkpoints (specific to torchft per-step fault tolerance)
I believe this change should only affect type 2, so it makes sense to keep it under the FaultTolerance dataclass.
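A minimal sketch of the split described in this reply, using simplified stand-ins for the config dataclasses. Only `keep_latest_k` and `checkpoint_keep_latest_k` are names from this thread; the class shapes, the `JobConfigSketch` container, and the reading of `0` as "keep everything" are assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    # Retention for full checkpoints (0 is assumed to mean "keep all").
    keep_latest_k: int = 0


@dataclass
class FaultTolerance:
    # Retention for per-replica torchft checkpoints only.
    checkpoint_keep_latest_k: int = 0


@dataclass
class JobConfigSketch:  # hypothetical container, not the real config class
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
    fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)


config = JobConfigSketch()
# Tighten retention for per-replica checkpoints without touching full ones.
config.fault_tolerance.checkpoint_keep_latest_k = 2
```

Keeping the two settings separate lets per-replica checkpoints, which are written far more often, be pruned aggressively while full checkpoints follow their own policy.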
torchtitan/components/checkpoint.py
Outdated
@@ -112,14 +112,19 @@ def purge_thread(purge_queue: queue.Queue):
if isinstance(path, Terminate):
    return
assert isinstance(path, str)
logger.info("Checkpointer is deleting %s.", path)

if not 'ft-replica' in path:
Can "ft-replica" be a variable instead, used across all the checks? Also, it looks like there is a misspelling in the folder name, since it's currently "replicat" lol. Can you update that?
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
Sure, I can do this
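One way the requested change could look, as a sketch only: the constant name `FT_REPLICA_PREFIX` and both helper functions are hypothetical and not taken from the PR. The folder name and the path check share a single definition, so a misspelling in one place can no longer drift from the other.

```python
import os

# Hypothetical module-level constant; the name is an assumption.
FT_REPLICA_PREFIX = "ft-replica"


def ft_replica_folder(base_folder: str, replica_id: str) -> str:
    """Build the per-replica checkpoint folder from the shared prefix."""
    return os.path.join(base_folder, f"{FT_REPLICA_PREFIX}-{replica_id}")


def is_ft_replica_path(path: str) -> bool:
    """Return True if the path points into a per-replica checkpoint folder."""
    return FT_REPLICA_PREFIX in path


folder = ft_replica_folder("checkpoints", "0")
```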
@@ -641,6 +647,7 @@ def _ft_save(self, step: int) -> None:
self.save_future = self.dcp_save(
    self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
)
self._purge_stale_ft_checkpoints()
Do we need to call this here? I thought the purge thread does the deletion.
Yes, but the directories to purge need to be added to the queue.
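The division of labor in this exchange can be sketched as follows: the purge thread only deletes what it receives, so the save path has to decide which directories are stale and enqueue them. This is an illustrative stand-in, not the PR's implementation; `enqueue_stale_checkpoints` and the `step-<N>` directory naming are assumptions.

```python
import os
import queue
import re
import shutil
import tempfile
import threading


class Terminate:
    """Sentinel that tells the purge thread to exit."""


def purge_thread(purge_queue: queue.Queue) -> None:
    # Consumer: deletes whatever paths are handed to it, nothing more.
    while True:
        path = purge_queue.get()
        if isinstance(path, Terminate):
            return
        shutil.rmtree(path, ignore_errors=True)


def enqueue_stale_checkpoints(folder: str, keep_latest_k: int,
                              purge_queue: queue.Queue) -> list:
    # Producer: picks everything except the newest k steps and queues it.
    if keep_latest_k <= 0:  # assumed to mean "keep everything"
        return []
    steps = []
    for name in os.listdir(folder):
        match = re.fullmatch(r"step-(\d+)", name)  # assumed naming scheme
        if match:
            steps.append((int(match.group(1)), name))
    steps.sort()
    stale = [os.path.join(folder, name) for _, name in steps[:-keep_latest_k]]
    for path in stale:
        purge_queue.put(path)
    return stale


# Usage: three checkpoints on disk, keep the latest two.
root = tempfile.mkdtemp()
for step in (10, 20, 30):
    os.makedirs(os.path.join(root, f"step-{step}"))

purge_queue = queue.Queue()
worker = threading.Thread(target=purge_thread, args=(purge_queue,), daemon=True)
worker.start()
enqueue_stale_checkpoints(root, keep_latest_k=2, purge_queue=purge_queue)
purge_queue.put(Terminate())
worker.join()
remaining = sorted(os.listdir(root))
```

Without the enqueue call after each save, the worker would sit idle and old checkpoints would accumulate.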
c19edff to 2818959
When using semi-sync training, FT checkpoints can start to take up a considerable amount of storage, and there is currently no mechanism to clean them up. This PR adds:
For this initial PR, I made the decision to disable logging for this deletion as it creates too much output, but this is up for discussion.