From 196aef5a6f76e1ad6ba889184860c3633d166910 Mon Sep 17 00:00:00 2001 From: Dimitri Barbot Date: Tue, 28 Jan 2025 14:46:41 +0100 Subject: [PATCH] Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670) Fix pipeline dtype unexpected change when using SDXL reference community pipelines --- .../community/stable_diffusion_xl_controlnet_reference.py | 8 +++++++- examples/community/stable_diffusion_xl_reference.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/community/stable_diffusion_xl_controlnet_reference.py b/examples/community/stable_diffusion_xl_controlnet_reference.py index ac3159e5e6e8..2c9bef311b0e 100644 --- a/examples/community/stable_diffusion_xl_controlnet_reference.py +++ b/examples/community/stable_diffusion_xl_controlnet_reference.py @@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device) - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: self.upcast_vae() refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if refimage.dtype != self.vae.dtype: @@ -223,6 +224,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do # aligning device to prevent device errors when concating it with the latent model input ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + return ref_image_latents def prepare_ref_image( diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 6439280cb185..e01eac970b58 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -139,7 +139,8 @@ def retrieve_timesteps( class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device) - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: self.upcast_vae() refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if refimage.dtype != self.vae.dtype: @@ -169,6 +170,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do # aligning device to prevent device errors when concating it with the latent model input ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + return ref_image_latents def prepare_ref_image(