
Commit 196aef5

Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)
dimitribarbot authored Jan 28, 2025
1 parent 7b100ce commit 196aef5
Showing 2 changed files with 14 additions and 2 deletions.
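
Background: when the VAE is in float16 and its `force_upcast` config flag is set, `prepare_ref_latents` upcasts the VAE to float32 before encoding the reference image (fp16 VAEs can overflow during encoding). The previous code never cast the VAE back, so a single reference-guided call left `pipe.vae` in float32 and silently changed the pipeline's effective dtype. The fix records whether the upcast happened and restores float16 before returning. A minimal sketch of the pattern, using an illustrative free function (`encode_reference` and its arguments are not part of the diff):

```python
import torch

def encode_reference(vae, refimage, generator):
    # Record whether we upcast, so the change can be undone afterwards.
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae.to(dtype=torch.float32)  # avoid fp16 overflow in the VAE encoder
    refimage = refimage.to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(refimage).latent_dist.sample(generator=generator)
    latents = vae.config.scaling_factor * latents
    if needs_upcasting:
        vae.to(dtype=torch.float16)  # cast back so the pipeline dtype is unchanged
    return latents
```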
8 changes: 7 additions & 1 deletion examples/community/stable_diffusion_xl_controlnet_reference.py

```diff
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPipeline):
 
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
 
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(
```
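
For context, a hedged sketch of loading this community pipeline in float16 mode, the configuration in which the lingering float32 VAE was observable (the model IDs are examples, not prescribed by the commit):

```python
import torch
from diffusers import ControlNetModel, DiffusionPipeline

# Example model IDs; any SDXL base + SDXL ControlNet pair should behave the same.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_xl_controlnet_reference",
    torch_dtype=torch.float16,
).to("cuda")
```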
8 changes: 7 additions & 1 deletion examples/community/stable_diffusion_xl_reference.py
```diff
@@ -139,7 +139,8 @@ def retrieve_timesteps(
 class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
 
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(
```
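
A quick smoke test for the fix against the plain SDXL reference pipeline, sketched under the assumption that the pipeline exposes the `ref_image`, `reference_attn`, and `reference_adain` call arguments documented for the reference community pipelines; before this commit the final assertion failed because the VAE stayed in float32:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="stable_diffusion_xl_reference",
    torch_dtype=torch.float16,
).to("cuda")

ref_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

_ = pipe(
    prompt="a photo of an astronaut",
    ref_image=ref_image,
    reference_attn=True,
    reference_adain=False,
).images[0]

assert pipe.vae.dtype == torch.float16  # stays fp16 after the fix
```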
