From 15ed53d27227c215ebd3e36fab7b537f23b4b105 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Sat, 3 Feb 2024 05:16:32 +0100
Subject: [PATCH] Fixes LoRA SDXL training script with DDP + PEFT (#6816)

Update train_dreambooth_lora_sdxl.py
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index a995eb3043dc..2cc2ab79db95 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1399,8 +1399,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             text_encoder_two.train()
 
             # set top parameter requires_grad = True for gradient checkpointing works
-            text_encoder_one.text_model.embeddings.requires_grad_(True)
-            text_encoder_two.text_model.embeddings.requires_grad_(True)
+            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
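
For context only (not part of the patch): under multi-GPU training, accelerator.prepare() wraps each text encoder in torch's DistributedDataParallel, so submodules such as text_model are only reachable through the wrapper's .module attribute; the old direct access therefore raised AttributeError, and accelerator.unwrap_model restores access to the underlying module. The sketch below reproduces the failure and the fix using hypothetical stand-in modules (TinyTextModel, TinyEncoder) rather than the real CLIP text encoders.

    # Minimal sketch (assumption: stand-in modules, not the real CLIP text encoders)
    # showing why the unwrap is needed once DDP wraps the model.
    import os

    import torch.distributed as dist
    import torch.nn as nn

    # Single-process "gloo" group, just enough to construct a DDP wrapper on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)


    class TinyTextModel(nn.Module):  # hypothetical stand-in for CLIPTextModel.text_model
        def __init__(self):
            super().__init__()
            self.embeddings = nn.Embedding(10, 4)


    class TinyEncoder(nn.Module):  # hypothetical stand-in for text_encoder_one / text_encoder_two
        def __init__(self):
            super().__init__()
            self.text_model = TinyTextModel()


    wrapped = nn.parallel.DistributedDataParallel(TinyEncoder())

    try:
        _ = wrapped.text_model  # what the old code effectively did under DDP
    except AttributeError as err:
        print(f"direct access fails: {err}")

    # Unwrapping first (roughly what accelerator.unwrap_model does for DDP) works:
    wrapped.module.text_model.embeddings.requires_grad_(True)
    print("requires_grad set on the unwrapped embeddings")

    dist.destroy_process_group()

In the training script itself, accelerator.unwrap_model is preferable to reaching for .module directly, since it returns the model unchanged when no wrapping happened and also covers the other wrappers Accelerate may apply.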