Skip to content

Commit 170833c

Browse files
[Fix] fp16 unscaling in train_dreambooth_lora_sdxl (#10889)
Fix fp16 bug.
Co-authored-by: Sayak Paul <[email protected]>
1 parent db21c97 commit 170833c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def log_validation(
203203

204204
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
205205

206-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
206+
pipeline = pipeline.to(accelerator.device)
207207
pipeline.set_progress_bar_config(disable=True)
208208

209209
# run inference
@@ -213,7 +213,7 @@ def log_validation(
213213
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
214214
autocast_ctx = nullcontext()
215215
else:
216-
autocast_ctx = torch.autocast(accelerator.device.type)
216+
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
217217

218218
with autocast_ctx:
219219
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

0 commit comments

Comments (0)