Model Training Fails at “Loading checkpoint shards” with SIGKILL (Out of Memory?) on RTX 5090 #12278
Unanswered · tianyu1997 asked this question in Q&A · 0 replies
Hi all,
I’m trying to train a LoRA model using the train_dreambooth_lora_flux.py script from the latest diffusers repo with the model black-forest-labs/FLUX.1-dev on an NVIDIA RTX 5090. However, the process is killed (SIGKILL: 9) right at the “Loading checkpoint shards” step, before any debug print statements in the script are reached.
Environment:
PyTorch version: 2.4.1+cu121 (nightly, supports CUDA 12.1)
GPU: NVIDIA GeForce RTX 5090 (sm_120)
Command:
accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-dev \
  --instance_data_dir=dog \
  --output_dir=trained-flux-lora \
  --mixed_precision=bf16 \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
What I’ve tried:
Lowering resolution and batch size
Using fp16 instead of bf16
Ensuring no other processes are using GPU memory
Upgrading to the latest nightly PyTorch
Adding debug prints: the process is killed before model loading completes
Error message:
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]
... (then process is killed with SIGKILL: 9)
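A note on the memory hypothesis: a GPU out-of-memory condition normally surfaces as a Python exception, whereas a SIGKILL during checkpoint loading on Linux is more typical of the kernel OOM killer reacting to exhausted host RAM. The sketch below is one rough way to compare the expected weight size against available system memory. It assumes a Linux host with `/proc/meminfo`, and it assumes roughly 12 billion parameters for the FLUX.1-dev transformer; the helper names `estimate_weight_bytes` and `parse_mem_available_bytes` are hypothetical, and the estimate ignores the text encoders, the VAE, and any temporary copies made during loading.

```python
"""Rough host-RAM sanity check before loading a large checkpoint (Linux only)."""
from pathlib import Path
from typing import Optional

GIB = 1024 ** 3


def estimate_weight_bytes(n_params: int, bytes_per_param: int) -> int:
    """Lower bound on memory needed just to hold the weights."""
    return n_params * bytes_per_param


def parse_mem_available_bytes(meminfo_text: str) -> Optional[int]:
    """Extract MemAvailable (reported in kB) from /proc/meminfo text."""
    for line in meminfo_text.splitlines():
        if line.startswith("MemAvailable:"):
            return int(line.split()[1]) * 1024
    return None


if __name__ == "__main__":
    # Assumption: about 12e9 parameters for the transformer alone.
    n_params = 12_000_000_000
    meminfo = Path("/proc/meminfo")
    available = (
        parse_mem_available_bytes(meminfo.read_text()) if meminfo.exists() else None
    )
    for label, width in (("bf16/fp16", 2), ("fp32", 4)):
        need = estimate_weight_bytes(n_params, width)
        print(f"{label}: weights need at least {need / GIB:.1f} GiB of host RAM")
        if available is not None and need > available:
            print(f"  only {available / GIB:.1f} GiB available: OOM kill is plausible")
```

If the available figure is below the estimate, checking the kernel log for an OOM-killer entry would help confirm whether host memory, rather than the GPU, is the limiting factor.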
PyTorch warning:
NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_70 sm_75 sm_80 sm_86 sm_90.
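The warning above can be reproduced as a small check. This is a simplified sketch, assuming that a build is usable when it ships kernels for the same major compute capability as the device; it is not PyTorch's exact logic, and the helper names `supported_majors` and `is_supported` are hypothetical. The only PyTorch calls used are `torch.cuda.is_available()`, `torch.cuda.get_arch_list()`, and `torch.cuda.get_device_capability()`, and they run only if `torch` is installed.

```python
"""Check whether a PyTorch build's arch list covers a GPU's compute capability."""
import re
from typing import Iterable, Set, Tuple


def supported_majors(arch_list: Iterable[str]) -> Set[int]:
    """Collect major capability versions from entries such as 'sm_86' or 'sm_90a'."""
    majors = set()
    for arch in arch_list:
        match = re.fullmatch(r"sm_(\d+)[a-z]?", arch)
        if match:
            majors.add(int(match.group(1)) // 10)  # 'sm_120' -> 12, 'sm_86' -> 8
    return majors


def is_supported(capability: Tuple[int, int], arch_list: Iterable[str]) -> bool:
    """Simplified rule: the device major version must appear in the build."""
    major, _minor = capability
    return major in supported_majors(arch_list)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        archs = torch.cuda.get_arch_list()
        cap = torch.cuda.get_device_capability(0)
    else:
        # Fall back to the values quoted in the warning from this post.
        archs = ["sm_50", "sm_60", "sm_70", "sm_75", "sm_80", "sm_86", "sm_90"]
        cap = (12, 0)  # RTX 5090, sm_120

    print(f"device capability: sm_{cap[0]}{cap[1]}")
    print(f"build arch list:   {' '.join(archs)}")
    print("supported" if is_supported(cap, archs) else "NOT supported by this build")
```

With the arch list from the warning, an sm_120 device is reported as unsupported, which suggests the installed cu121 wheel predates Blackwell and that a build compiled against a newer CUDA toolkit would be needed regardless of the memory question.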
Questions:
Is this a memory issue, or is it due to lack of support for sm_120 in current PyTorch builds?
Has anyone successfully run large models on the RTX 5090 with PyTorch?
Any workarounds or advice for getting this to work?
Thanks in advance for any help!