Skip to content

Commit 15ed53d

Browse files
Fixes LoRA SDXL training script with DDP + PEFT (#6816)
Update train_dreambooth_lora_sdxl.py
1 parent 9cc59ba commit 15ed53d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,8 +1399,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
13991399
text_encoder_two.train()
14001400

14011401
# set top parameter requires_grad = True for gradient checkpointing works
1402-
text_encoder_one.text_model.embeddings.requires_grad_(True)
1403-
text_encoder_two.text_model.embeddings.requires_grad_(True)
1402+
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1403+
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
14041404

14051405
for step, batch in enumerate(train_dataloader):
14061406
with accelerator.accumulate(unet):

0 commit comments

Comments
 (0)