
Conversation


@efazal efazal commented Oct 18, 2025

What does this PR do?

This PR adds Just-In-Time (JIT) asynchronous checkpointing to the Trainer API, enabling graceful handling of training interruptions in preemptible workloads. When enabled via enable_jit_checkpoint=True, the trainer automatically saves progress upon receiving a SIGTERM signal, preventing loss of training progress in shared cluster environments (e.g., clusters managed by Kueue).

Motivation

Modern ML training increasingly relies on orchestration platforms like Kubeflow and workload managers like Kueue in shared Kubernetes clusters. These platforms require the ability to:

  1. Pause and Resume Jobs: Dynamically reallocate resources by pausing long-running training jobs and resuming them later without losing progress
  2. Kueue Preemption: Support preemptible workloads where lower-priority jobs can be terminated (via SIGTERM) to make room for higher-priority workloads
  3. Dynamic Scaling: Enable training jobs to scale down/up based on cluster resource availability

Currently, if a training job is terminated before the next periodic checkpoint, all progress since the last checkpoint is lost. This wastes significant compute resources and training time.

Implementation

New Components:
- trainer_jit_checkpoint.py: core implementation with CheckpointManager and JITCheckpointCallback
  - Registers a SIGTERM signal handler to intercept termination signals
  - Implements a 3-second kill-wait period to distinguish SIGTERM from SIGKILL
  - Runs the checkpoint on a separate thread to avoid blocking the training loop
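
A minimal sketch of how these pieces could fit together (illustrative only; save_fn and kill_wait_seconds are placeholder names, not the PR's actual API):

```python
import os
import signal
import threading
import time


class CheckpointManager:
    """Illustrative sketch: intercept SIGTERM and save from a worker thread."""

    def __init__(self, save_fn, kill_wait_seconds: float = 3.0):
        self._save_fn = save_fn                  # callable that writes the checkpoint
        self._kill_wait = kill_wait_seconds
        self._triggered = threading.Event()
        signal.signal(signal.SIGTERM, self._sigterm_handler)

    def _sigterm_handler(self, signum, frame):
        if self._triggered.is_set():             # ignore duplicate SIGTERMs
            return
        self._triggered.set()
        # Run the save on a separate thread so the handler returns immediately.
        threading.Thread(target=self._checkpoint_and_exit, daemon=True).start()

    def _checkpoint_and_exit(self):
        self._save_fn()
        time.sleep(self._kill_wait)              # short kill-wait window before shutting down
        os._exit(0)                              # terminate once the checkpoint is on disk
```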

Modified Components:

  • trainer.py: Automatically adds JITCheckpointCallback when enable_jit_checkpoint=True
  • training_args.py: New enable_jit_checkpoint boolean flag (configurable via ENABLE_JIT_CHECKPOINT environment variable)

Environment Variable Support Rationale:
The new environment variables (ENABLE_JIT_CHECKPOINT, SAVE_STRATEGY, SAVE_STEPS, SAVE_TOTAL_LIMIT, OUTPUT_DIR, RESUME_FROM_CHECKPOINT) allow orchestration platforms like Kubeflow to automatically configure checkpointing behavior
without requiring users to modify their training scripts. This enables:

  • Centralized checkpoint configuration through cluster operators
  • Dynamic checkpoint policies based on workload type
  • Seamless integration with existing training code
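
As a rough illustration of what this enables (not the PR's internal wiring), a training script or launcher could consume these variables along these lines; enable_jit_checkpoint is the flag added by this PR, the rest are existing TrainingArguments fields:

```python
import os

from transformers import TrainingArguments

# Illustrative only: fall back to script defaults when the orchestrator sets nothing.
# RESUME_FROM_CHECKPOINT would analogously feed trainer.train(resume_from_checkpoint=...).
args = TrainingArguments(
    output_dir=os.environ.get("OUTPUT_DIR", "./outputs"),
    save_strategy=os.environ.get("SAVE_STRATEGY", "steps"),
    save_steps=int(os.environ.get("SAVE_STEPS", "500")),
    save_total_limit=int(os.environ.get("SAVE_TOTAL_LIMIT", "3")),
    enable_jit_checkpoint=os.environ.get("ENABLE_JIT_CHECKPOINT", "false").lower() == "true",
)
```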

Key Features:

  • Graceful training termination after checkpoint completion
  • Environment variable support for easier configuration in cluster environments
  • Compatible with existing checkpoint management (save_steps, save_strategy, etc.)

Testing

Comprehensive unit tests added in tests/trainer/test_trainer_jit_checkpoint.py covering:

  • CheckpointManager initialization and configuration
  • SIGTERM signal handler flow and duplicate signal handling
  • Callback integration with Trainer lifecycle hooks

Fixes #38961

Who can review?

@SunMarc - This PR adds emergency checkpointing functionality to the Trainer class

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Updated Trainer args doc.
  • Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@efazal
Author

efazal commented Oct 20, 2025

@SunMarc @Rocketknight1 could you please review this PR? Thank you.

@Rocketknight1
Member

@SunMarc this is definitely a pure or nearly-pure code agent PR, so leaving it up to you to decide if you want it in Trainer; feel free to close if you don't

@efazal
Author

efazal commented Oct 20, 2025

@SunMarc AI was indeed used to write tests, this PR description, and so on. I'm wondering if there is a policy against using AI whatsoever in the transformers project. I've manually tested the code with multiple different models to verify that JIT checkpointing works on SIGTERM signals, for over a month now. I'm willing to take review feedback and address it. I've tried to implement this without modifying the trainer code too much, by keeping it separate and flag-enabled. I would appreciate it if you could review it, please. Thank you.

@efazal
Author

efazal commented Oct 21, 2025

@sgugger @stas00 could you please review this PR? Thank you.

@sfc-gh-sbekman

sfc-gh-sbekman commented Oct 21, 2025

@efazal, Sylvain and I are no longer with HF.

In general the PR looks great and +1 (for whatever it counts) to add this functionality. Let's discuss the implementation.

Question: I had similar functionality implemented in other frameworks I worked on, but I was using a check between trainer steps.

Otherwise, how do you know that when you trigger the weight saving you're not in the middle of a weight update, which could leave you with inconsistent weights?

The way I did it is via a flag: SIGTERM -> raise a flag -> the trainer checks the flag as soon as the step finishes -> goes into checkpoint saving -> exits. No threads are needed and the weights are consistent.
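
Roughly (framework-agnostic sketch; training_step and save_checkpoint are placeholders):

```python
import signal


def train_with_graceful_sigterm(model, optimizer, dataloader, training_step, save_checkpoint):
    """Flag-based graceful stop: the handler only raises a flag; saving happens between steps."""
    stop_requested = False

    def _on_sigterm(signum, frame):
        nonlocal stop_requested
        stop_requested = True                  # do nothing heavy inside the signal handler

    signal.signal(signal.SIGTERM, _on_sigterm)

    for step, batch in enumerate(dataloader):
        training_step(model, batch)            # forward/backward
        optimizer.step()
        optimizer.zero_grad()
        if stop_requested:                     # checked only at a step boundary -> weights are consistent
            save_checkpoint(model, optimizer, step)
            break                              # exit gracefully after the checkpoint completes
```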

I also see no benefit to doing async saving here, because the trainer can't continue anyway; the last step's results will be lost regardless if SIGTERM caught it mid-step.

@efazal
Author

efazal commented Oct 21, 2025

Thanks for your review @sfc-gh-sbekman 🙏, and apologies, I didn't realize you were no longer with HF.

You made a great point about a checkpoint being triggered during a weight update (optimizer pass). To avoid checkpointing mid-update, I create a new CUDA stream, wait for the operations on the main CUDA stream to complete, and only once the weight update has finished does the checkpoint happen asynchronously on the new stream.
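
Roughly what I had in place (sketch only; save_fn is a placeholder, and the actual PR code differs in its details):

```python
import torch


def async_checkpoint(model, save_fn):
    """Sketch: wait for pending work (e.g. the optimizer step) before snapshotting on a side stream."""
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())   # don't start until the weight update is done
    with torch.cuda.stream(side_stream):
        cpu_state = {
            name: param.detach().to("cpu", non_blocking=True)
            for name, param in model.state_dict().items()
        }
    side_stream.synchronize()        # make sure the device-to-host copies finished
    save_fn(cpu_state)               # write the snapshot to disk off the hot path
```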

You're also right that async checkpointing is unnecessary here, given that training is going to stop anyway. I mainly wanted to bring this approach forward so it could later be extended to the existing periodic checkpointing, which is currently synchronous and blocks the GPU while the checkpoint is written, making trainings with periodic checkpointing enabled roughly 30% more expensive. But that approach may no longer be needed, since PyTorch recently addressed this problem by adding support for asynchronous checkpointing and for saving models in safetensors format. So I will update the PR to remove async checkpointing.

I like the approach of doing this at the end of the step, as you suggested. In the current implementation, _sigterm_handler() immediately triggers checkpointing. Instead, I will set a checkpoint_requested flag to True and trigger the checkpoint from the on_step_begin(), on_pre_optimizer_step(), and on_step_end() callbacks based on that flag. I'm planning to check the flag in three different callbacks because we are targeting graceful shutdowns: in Kubernetes the default graceful shutdown period is 30 seconds, so we need to checkpoint as soon as possible, and the time a checkpoint takes varies with the model being trained/fine-tuned. The 30-second default can be overridden, but I doubt most users will consciously do this. I hope this makes sense.
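
Roughly what I have in mind (sketch only; manager and save_now() stand in for the CheckpointManager, and the actual hook bodies may differ):

```python
from transformers import TrainerCallback


class JITCheckpointCallback(TrainerCallback):
    """Sketch: SIGTERM just sets a flag; the next safe hook saves once and stops training."""

    def __init__(self, manager):
        self.manager = manager                      # owns checkpoint_requested and the save logic
        self._done = False

    def _maybe_checkpoint(self, control):
        if self.manager.checkpoint_requested and not self._done:
            self._done = True                       # fire the JIT checkpoint at most once
            self.manager.save_now()                 # placeholder for the actual save
            control.should_save = False             # skip an overlapping periodic save
            control.should_training_stop = True     # leave the training loop gracefully

    def on_step_begin(self, args, state, control, **kwargs):
        self._maybe_checkpoint(control)

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        self._maybe_checkpoint(control)

    def on_step_end(self, args, state, control, **kwargs):
        self._maybe_checkpoint(control)
```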

@stas00
Contributor

stas00 commented Oct 21, 2025

I think your async checkpointing, as a replacement for the blocking normal-operation checkpointing, would be a fantastic addition - perhaps in a separate PR?

re: k8s - there should be a config setting in your job definition which tells the scheduler how much grace time your job needs to shut down. The bigger the model, the more time you usually need, and 30 secs will not be enough in most cases, since you need enough time for one full step plus saving. I recommend documenting how this can be done.

SLURM jobs definitely have this feature - it's covered in detail here: https://github.com/stas00/ml-engineering/tree/11111d63c92edf3c10c25515851083de75cac629/training/fault-tolerance#dealing-with-forced-job-preemption Shall we add the k8s graceful setup to that section? We can work together on it if you send a PR.

Additionally, when using such panic event-based saving techniques, it's important to also ensure the checkpoint was fully saved, as the carpet could get yanked a second too soon and the user won't be any the wiser about why they can't resume. Typically I use some sort of flag file that I write once the checkpoint has finished writing, so that resuming first checks that the flag file is there. Or this can be done in reverse: touch a "checkpoint-is-incomplete.txt" file and delete it when the checkpoint writing has finished, which would perhaps integrate better and be more obvious to a user inspecting the files.
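
Something along these lines (sketch; the file name and helpers are arbitrary):

```python
import os

SENTINEL = "checkpoint-is-incomplete.txt"


def save_checkpoint_atomically(checkpoint_dir, write_checkpoint):
    """Touch a sentinel before writing, remove it only after the checkpoint is fully on disk."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    sentinel_path = os.path.join(checkpoint_dir, SENTINEL)
    open(sentinel_path, "w").close()          # mark the checkpoint as in-progress
    write_checkpoint(checkpoint_dir)          # the actual (potentially interrupted) save
    os.remove(sentinel_path)                  # only reached if the save completed


def is_checkpoint_complete(checkpoint_dir):
    """On resume: a checkpoint is trustworthy only if the sentinel is gone."""
    return os.path.isdir(checkpoint_dir) and not os.path.exists(
        os.path.join(checkpoint_dir, SENTINEL)
    )
```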

@efazal
Author

efazal commented Oct 21, 2025

Yes, absolutely. I would contribute non-blocking checkpointing separately after this PR goes through.

I think that makes sense. For this feature to work consistently with larger models, the default 30-second period wouldn't suffice. And yes, there is a config which propagates the graceful shutdown period to the pod's terminationGracePeriodSeconds field; most cloud-native trainers, like Kubeflow for example, allow this via the CRD job's pod template.spec. I shall document this in the trainer arguments file for the enable_jit_checkpoint arg, and might take inspiration from your ml-engineering repo to let the user know how to decide on the graceful shutdown period.

I'm interested in raising a PR to your ml-engineering repo with regards to k8s graceful shutdown. I shall ping you when I do so.

Also, that's another valid problem you raised. I did in fact run into incomplete checkpoint files when resuming training during tests with larger models. I addressed it by performing a check within my training function to see whether the latest checkpoint contains the files model.safetensors, optimizer.pt, scheduler.pt, and rng_state.pth, to confirm the checkpoint is valid, otherwise falling back to an earlier checkpoint or starting over. But I find your idea of doing it in reverse brilliant, and I shall do that. Perhaps later, as part of a separate PR, we could leverage this file to determine whether training can resume from the latest checkpoint or should defer to an earlier one if available.

Thanks a mil for the great ideas. I shall implement them soon.

@stas00
Contributor

stas00 commented Oct 22, 2025

Glad to hear you found my suggestions useful, Esa

by performing a check within my training function to see whether the latest checkpoint contains the files model.safetensors, optimizer.pt, scheduler.pt, and rng_state.pth.

this is insufficient since the file could exist but be incomplete if the writing got aborted mid-way. But as you said you will switch to a sentinel control file approach.

You might like the other solution I came up with when training BLOOM-176B - the kill and save switches:

if you have those in place, you don't even need a flag on SIGTERM - all you need is to drop a save-switch file, but either way works. The variable flag is likely to be a tad faster.

Looking forward to your PRs wrt k8s recipes; I'm a newbie there, so I need more time to gain expertise before I can share recipes with others. I like SLURM so much better - the level of complexity in k8s is insane compared to SLURM. You need one sysadmin to manage a SLURM cluster; you need a whole team to deal with k8s.

…ting. Introduce sentinal file to identify incomplete checkpoints. Update trainer arg doc and tests.
@efazal efazal force-pushed the feat-jit-checkpointing branch from a457b74 to 05411f3 Compare October 22, 2025 16:50
@efazal
Author

efazal commented Oct 22, 2025

this is insufficient since the file could exist but be incomplete if the writing got aborted mid-way. But as you said you will switch to a sentinel control file approach

That's true, definitely not a sustainable approach. I've implemented it as you suggested, and it works great.

Yes, I did read the fault-tolerance doc; it was certainly a very interesting read. I suppose we are achieving the kill-switch approach via a flag, and yes, it's a bit quicker since we can trigger the checkpoint in several callbacks.

I've now updated the PR based on your comments so far. I still need to verify this manually, which I will do next and confirm.

Again, thanks for all the info you've provided thus far, @sfc-gh-sbekman. Please have a look at the PR and let me know if you have any particular concerns. I've updated the docs to mention the grace period for both the K8s and SLURM orchestrators.

Contributor

@stas00 stas00 left a comment


one small doc suggestion, otherwise looks good.

@stas00
Contributor

stas00 commented Oct 22, 2025

Should these be more explicit and carry a trainer prefix: SAVE_STRATEGY, SAVE_STEPS, SAVE_TOTAL_LIMIT?

In the context of the code it's obvious what save_ implies, but as an env var this is too ambiguous. Perhaps SAVE_CHECKPOINT_*?

Oddly, I don't see the Trainer using a consistent prefix for its env vars - maybe I missed something. Ideally these should start with HF_TRAINER_ or some such so as not to pollute the env, but ask the current maintainers which one to use; e.g. HF Accelerate env vars all start with ACCELERATE_.

@efazal
Author

efazal commented Oct 22, 2025

Good catch on the env prefix concern. Looking around the repo, I could see the HF_ prefix used for HF_HOME and HF_TOKEN, and since there are other trainers like TRL that are also based on the HF Transformers Trainer, I updated the checkpoint variables likewise, i.e. HF_ENABLE_JIT_CHECKPOINT, so TRL users can also benefit from this feature and it works across the HF ecosystem. I hope that's a good idea; maybe @SunMarc can correct me here.

I did a bunch of manual testing by fine-tuning a llama3.2 8B instruct model in a multi-node setting. It works as well as before, except now it's synchronous. I also tested JIT checkpointing coinciding with a periodic checkpoint, to see how it behaves if the JIT checkpoint is triggered right before or during the periodic one. The results were as follows:

  • JIT checkpoint triggered right before the periodic checkpoint: since we set control.should_save = False when the JIT checkpoint is triggered, the periodic checkpoint doesn't run and the job is terminated, which is what we expect.
  • JIT checkpoint triggered during the periodic checkpoint: the periodic checkpoint completes and the job is terminated. But I suspect that if the grace period were excessive, the training loop would continue after the periodic checkpoint and the callback would trigger the JIT checkpoint as well, which could lead to an incomplete checkpoint. WDYT? @sfc-gh-sbekman

@efazal
Author

efazal commented Oct 23, 2025

@SunMarc could you please add other reviewers to this PR? I would also appreciate your review. Thank you!
