feat(trainer): Just-in-time (JIT) asynchronous checkpointing using SIGTERM signals #41723
Conversation
@SunMarc @Rocketknight1 could you please review this PR? Thank you.
@SunMarc this is definitely a pure or nearly-pure code agent PR, so leaving it up to you to decide if you want it in.
@SunMarc AI was indeed used to write tests, this PR description, and so on. I'm wondering if there is a policy against using AI whatsoever in the transformers project. I've manually tested the code with multiple different models for over a month now to verify that JIT checkpointing works on SIGTERM signals. I'm willing to take reviews and address them. I've tried to implement this without modifying the trainer code too much, by keeping it separate and flag-enabled. I would appreciate it if you could review it. Thank you.
@efazal, Sylvain and I are no longer with HF. In general the PR looks great and +1 (for whatever it counts) to add this functionality. Let's discuss the implementation. Question: I had similar functionality implemented in other frameworks I worked on, but I was using a check between trainer steps. Otherwise, how do you know that when you trigger the weight saving you're not in the middle of a weight update, which could leave you with inconsistent weights? The way I did it is via a flag: SIGTERM -> raise a flag -> the trainer checks the flag as soon as the step is finished -> goes into checkpoint saving -> exit. No threads are needed and the weights are consistent. I also see no benefit in doing async saving here, because the trainer can't continue anyway, so the last step's results will be lost regardless if SIGTERM caught it mid-step.
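For readers of this thread, a minimal sketch of the flag-based pattern described above; the names (`_sigterm_received`, `training_loop`, `save_checkpoint`) are illustrative and not taken from the PR or any framework:

```python
import signal

# Illustrative sketch of the flag-based pattern; names are hypothetical.
_sigterm_received = False

def _handle_sigterm(signum, frame):
    # Do no real work inside the signal handler: just raise a flag.
    global _sigterm_received
    _sigterm_received = True

signal.signal(signal.SIGTERM, _handle_sigterm)

def training_loop(model, optimizer, dataloader, save_checkpoint):
    for step, batch in enumerate(dataloader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if _sigterm_received:
            # The step has fully finished, so the weights are consistent.
            save_checkpoint(model, optimizer, step)
            break
```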
Thanks for your review @sfc-gh-sbekman 🙏, and apologies, I didn't realize you were no longer with HF. You made a great point about a checkpoint being triggered during a weight update (optimizer pass). To avoid checkpointing during the optimizer pass, I create a new CUDA stream, wait for the main CUDA stream's operations to complete, and only after the weight update has finished does checkpointing happen asynchronously on the new stream. You also rightly said that async checkpointing is unnecessary given that training is going to stop anyway. I mainly wanted to bring this approach forward to possibly extend async checkpointing to the existing periodic checkpointing, since that is currently synchronous and blocks the GPU while a periodic checkpoint takes place, making trainings roughly 30% more expensive (when periodic checkpointing is enabled). But I think this approach is no longer needed, since PyTorch recently addressed this problem and added support for asynchronous checkpointing, as well as for saving models in safetensors format. So I will update the PR to remove async checkpointing. I like the approach of doing this at the end of a step, as you suggested. In the current implementation, the
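For reference, the side-stream idea described in this comment could look roughly like the following. This is only a sketch that assumes a single CUDA device; the helper name is made up and this is not the PR's actual code:

```python
import torch

def snapshot_state_dict_on_side_stream(model):
    """Copy parameters to CPU on a separate CUDA stream so the main stream
    is not blocked. Hypothetical helper, not the PR's implementation."""
    side_stream = torch.cuda.Stream()
    # Make the side stream wait for all work already queued on the main stream
    # (forward/backward/optimizer.step) before copying anything.
    side_stream.wait_stream(torch.cuda.current_stream())
    cpu_state = {}
    with torch.cuda.stream(side_stream):
        for name, param in model.state_dict().items():
            # non_blocking copies only overlap with compute for pinned memory;
            # real code would pre-allocate pinned destination buffers.
            cpu_state[name] = param.detach().to("cpu", non_blocking=True)
    # Synchronize before serializing to disk so the copies have finished.
    side_stream.synchronize()
    return cpu_state
```

The newer PyTorch facility referred to above is presumably `torch.distributed.checkpoint.async_save`, which would make a hand-rolled stream dance like this unnecessary.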
I think your async checkpointing as a replacement for blocking normal-operation checkpointing would be a fantastic addition - perhaps in a separate PR? re: k8s - there should be a config setting in your job definition which can tell the scheduler how much grace time your job needs to shut down. The bigger the model, the more time you usually need, and 30 secs will not be enough in most cases, since you need enough time for 1 full step + saving. I recommend documenting how this can be done. SLURM jobs definitely have this feature - it's covered in detail here: https://github.com/stas00/ml-engineering/tree/11111d63c92edf3c10c25515851083de75cac629/training/fault-tolerance#dealing-with-forced-job-preemption Let's add the k8s graceful setup to that section? We can work together on it if you PR in. Additionally, when using such panic event-based saving techniques it's important to also ensure the checkpoint was fully saved, as the carpet could get yanked from under it 1 sec too soon and the user won't be the wiser as to why they can't resume. Typically I use some sort of flag file which I write when the checkpoint has finished writing, so when resuming it would first check that the flag file is there. Or this can be done in reverse: touch a "checkpoint-is-incomplete.txt" file and then delete it when the checkpoint writing has finished, which perhaps would integrate better and be more obvious to a user inspecting the files.
Yes, absolutely. I would contribute non-blocking checkpointing separately after this PR goes through. I think that makes sense. For this feature to work consistently with larger models, the default 30 sec grace period wouldn't suffice. And yes, there is a config which can propagate the graceful shutdown period to the pods. I'm interested in raising a PR to your ml-engineering repo with regards to k8s graceful shutdown; I shall ping you when I do so. Also, another valid problem you raised: I did in fact run into incomplete checkpoint files when trying to resume training when I tested with larger models. I addressed this by performing a check within my training function to verify that the latest checkpoint contains the files model.safetensors, optimizer.pt, scheduler.pt, and rng.states.pth, confirming the checkpoint is valid, and otherwise falling back to an earlier checkpoint to resume, or starting over. But I find your idea of doing it in reverse brilliant, and I shall do that. And perhaps later, as part of a separate PR, we could leverage this file to determine whether training can be resumed from the latest checkpoint, or defer to an earlier checkpoint if available. Thanks a mil for the great ideas. I shall implement them soon.
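A minimal sketch of the reverse ("incomplete marker") approach being adopted here; the marker file name comes from the comment above, while the helper names and `save_fn` are illustrative:

```python
import os

INCOMPLETE_MARKER = "checkpoint-is-incomplete.txt"  # name taken from the discussion above

def write_checkpoint(ckpt_dir, save_fn):
    """Touch the marker first, save, then remove the marker last.
    Hypothetical helper; `save_fn` does the actual writing."""
    os.makedirs(ckpt_dir, exist_ok=True)
    marker = os.path.join(ckpt_dir, INCOMPLETE_MARKER)
    open(marker, "w").close()   # checkpoint is now "in progress"
    save_fn(ckpt_dir)           # model / optimizer / scheduler / rng state
    os.remove(marker)           # only a finished checkpoint lacks the marker

def is_checkpoint_complete(ckpt_dir):
    # On resume, skip any directory that still carries the marker.
    return os.path.isdir(ckpt_dir) and not os.path.exists(
        os.path.join(ckpt_dir, INCOMPLETE_MARKER)
    )
```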
Glad to hear you found my suggestions useful, Esa
this is insufficient since the file could exist but be incomplete if the writing got aborted mid-way. But as you said you will switch to a sentinel control file approach. You might like the other solution I came up with when training BLOOM-176B - the kill and save switches:
if you have those in place, you don't even need a flag on sigterm - now all you need is to drop a save switch file, but either way it works. The variable flag is likely to be a tad faster. Looking forward to your PRs wrt k8s recipes; I'm a newbie there, so I need more time to gain expertise before I can share recipes with others. I like SLURM so much better. The k8s level of complexity is insane compared to SLURM: you need one sysadmin to manage a SLURM cluster, whereas you need a whole team to deal with k8s.
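For anyone reading along, a small illustration of what such switch files could look like when checked at the end of each step; the paths and helper names are made up, and this is not the BLOOM training's actual code:

```python
import os

# Illustrative control-file paths; any location on a shared filesystem works.
KILL_SWITCH = "output/kill-switch"
SAVE_SWITCH = "output/save-switch"

def check_control_files(save_checkpoint, stop_training):
    """Called at the end of every training step (e.g. from a callback)."""
    if os.path.exists(SAVE_SWITCH):
        save_checkpoint()
        os.remove(SAVE_SWITCH)   # consume the switch so it fires only once
    if os.path.exists(KILL_SWITCH):
        save_checkpoint()
        stop_training()          # exit cleanly after saving
```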
…ting. Introduce sentinel file to identify incomplete checkpoints. Update trainer arg doc and tests. (force-pushed from a457b74 to 05411f3)
That's true, definitely not a sustainable approach. I've implemented it as you suggested and it's a great approach. Yes, I did read the fault-tolerance doc; it was certainly a very interesting read. Yeah, I suppose we are trying to achieve the kill-switch approach via a flag, and yes, it's a bit quicker since we can trigger a checkpoint from many callbacks. I've updated the PR now based on your comments so far. I still need to manually verify this, which I will do next and confirm. Again, thanks for all the info you've provided thus far @sfc-gh-sbekman. Please have a look at the PR and let me know if you have any concerns in particular. I've updated the docs to mention the grace time period for both the K8s and SLURM orchestrators.
one small doc suggestion, otherwise looks good.
should these be more explicit and have a trainer prefix? In the context of the code it's obvious what they refer to. Oddly, I don't see the trainer using a consistent prefix for its env vars, maybe I missed something - ideally those should start with `HF_`.
…elated envs with HF_ prefix.
Good catch on the env prefix concern. Upon looking around in the repo, I could see the same, so the related envs now use the HF_ prefix (per the commit above). I did a bunch of manual testing by finetuning a llama3.2 8B instruct model in a multi-node setting. It works pretty well as before, except now it's synchronous. I also tested JIT checkpointing by making it coincide with a periodic checkpoint to see how it behaves if a JIT checkpoint is triggered before and during a periodic one. The following were the results:
@SunMarc could you please add other reviewers to this PR? Would also appreciate your review. Thank you!
What does this PR do?
This PR adds Just-In-Time (JIT) asynchronous checkpointing to the Trainer API, enabling graceful handling of training interruptions in preemptible workloads. When enabled via `enable_jit_checkpoint=True`, the trainer automatically saves progress upon receiving a SIGTERM signal, preventing loss of training progress in shared cluster environments (e.g., Kueue).
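Assuming this PR's flag, enabling the feature would look roughly like the sketch below (`model` and `train_dataset` are assumed to be defined elsewhere; `enable_jit_checkpoint` is not part of a released transformers version):

```python
from transformers import Trainer, TrainingArguments

# `enable_jit_checkpoint` is the flag introduced by this PR.
args = TrainingArguments(
    output_dir="out",
    save_strategy="steps",
    save_steps=500,
    enable_jit_checkpoint=True,  # save a checkpoint when SIGTERM is received
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```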
Motivation
Modern ML training increasingly relies on orchestration platforms like Kubeflow and workload managers like Kueue in shared Kubernetes clusters. These platforms need to be able to preempt and reschedule training jobs gracefully, typically by sending SIGTERM before terminating a pod.
Currently, if a training job is terminated before the next periodic checkpoint, all progress since the last checkpoint is lost. This wastes significant compute resources and training time.
Implementation
New Components:
- `trainer_jit_checkpoint.py`: Core implementation with `CheckpointManager` and `JITCheckpointCallback`
- Registers SIGTERM signal handler to intercept termination signals
- Implements 3-second kill-wait period to distinguish SIGTERM from SIGKILL
- Thread-based checkpoint execution to avoid blocking training loop
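For illustration only, here is a hypothetical callback showing how SIGTERM handling can be wired into the Trainer via the existing callback API. It follows the flag-at-step-end variant discussed in the conversation rather than a separate thread, and it is not the PR's actual code:

```python
import signal
from transformers import TrainerCallback

class SigtermCheckpointCallback(TrainerCallback):
    """Hypothetical callback: raise a flag on SIGTERM, then ask the Trainer to
    save and stop at the end of the current step. Not the PR's actual code."""

    def __init__(self):
        self._sigterm_received = False
        signal.signal(signal.SIGTERM, self._handle_sigterm)

    def _handle_sigterm(self, signum, frame):
        self._sigterm_received = True  # do nothing heavy inside the handler

    def on_step_end(self, args, state, control, **kwargs):
        if self._sigterm_received:
            control.should_save = True           # Trainer saves after this step
            control.should_training_stop = True  # then exits the training loop
        return control
```

Such a callback could be passed to the Trainer via `callbacks=[SigtermCheckpointCallback()]`, or, as this PR does, added automatically when the flag is set.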
Modified Components:
- `trainer.py`: Automatically adds `JITCheckpointCallback` when `enable_jit_checkpoint=True`
- `training_args.py`: New `enable_jit_checkpoint` boolean flag (configurable via the `ENABLE_JIT_CHECKPOINT` environment variable)

Environment Variable Support Rationale:
The new environment variables (`ENABLE_JIT_CHECKPOINT`, `SAVE_STRATEGY`, `SAVE_STEPS`, `SAVE_TOTAL_LIMIT`, `OUTPUT_DIR`, `RESUME_FROM_CHECKPOINT`) allow orchestration platforms like Kubeflow to automatically configure checkpointing behavior without requiring users to modify their training scripts (see the sketch below).
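A rough sketch of how a platform-injected environment could drive these settings. The env var names are taken from this PR's description, while the mapping code itself is illustrative (and, per the review above, the variables may end up with an `HF_` prefix):

```python
import os
from transformers import TrainingArguments

# Illustrative mapping from platform-provided env vars to TrainingArguments;
# `enable_jit_checkpoint` is the flag introduced by this PR.
args = TrainingArguments(
    output_dir=os.environ.get("OUTPUT_DIR", "out"),
    save_strategy=os.environ.get("SAVE_STRATEGY", "steps"),
    save_steps=int(os.environ.get("SAVE_STEPS", "500")),
    save_total_limit=int(os.environ.get("SAVE_TOTAL_LIMIT", "2")),
    enable_jit_checkpoint=os.environ.get("ENABLE_JIT_CHECKPOINT", "0") == "1",
)
```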
Key Features:
Testing
Comprehensive unit tests added in `tests/trainer/test_trainer_jit_checkpoint.py`.

Fixes #38961
Who can review?
@SunMarc - This PR adds emergency checkpointing functionality to the Trainer class
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.