Skip to content

Commit

Permalink
Fixes LoRA SDXL training script with DDP + PEFT (#6816)
Browse files Browse the repository at this point in the history
Update train_dreambooth_lora_sdxl.py
  • Loading branch information
younesbelkada authored Feb 3, 2024
1 parent 9cc59ba commit 15ed53d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_two.train()

# set top parameter requires_grad = True for gradient checkpointing works
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
Expand Down

0 comments on commit 15ed53d

Please sign in to comment.