Commit 6c5f0de

Support latents_mean and latents_std (#7132)

Authored by haofanwang, with ResearcherXman, Sayak Paul, and YiYi Xu as co-authors.

Commit message:
* update latents_mean and latents_std
* fix typos
* Update src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
* format

Co-authored-by: ResearcherXman <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

1 parent: e64fdcf

File tree: 5 files changed (+81, -8 lines)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
16 additions, 1 deletion

@@ -1460,7 +1460,22 @@ def __call__(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:
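Note (not part of the diff): the same unscale/denormalize branch is repeated in each of the five modified pipelines. As a minimal standalone sketch of what it computes, with made-up placeholder values for the VAE config (the numbers below are illustrative only, not taken from the commit):

import torch

# Hypothetical stand-ins for self.vae.config values, for illustration only.
scaling_factor = 0.13025
latents_mean = [0.1, -0.2, 0.05, 0.0]  # per-channel mean, one entry per latent channel
latents_std = [1.1, 0.9, 1.0, 1.2]     # per-channel std, one entry per latent channel

latents = torch.randn(1, 4, 128, 128)  # dummy SDXL-shaped latents

if latents_mean is not None and latents_std is not None:
    # reshape to (1, 4, 1, 1) so the per-channel stats broadcast over batch and spatial dims
    mean = torch.tensor(latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    std = torch.tensor(latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    latents = latents * std / scaling_factor + mean
else:
    # previous behavior: only undo the scaling factor
    latents = latents / scaling_factor

# `latents` is now in the VAE's native range and ready to be passed to vae.decode(...)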

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
16 additions, 1 deletion

@@ -1587,7 +1587,22 @@ def __call__(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
16 additions, 2 deletions

@@ -1404,14 +1404,28 @@ def denoising_value_valid(dnv):
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:
                 self.vae.to(dtype=torch.float16)
         else:
             image = latents
-            return StableDiffusionXLPipelineOutput(images=image)

         # apply watermark if available
         if self.watermark is not None:

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
16 additions, 1 deletion

@@ -1771,7 +1771,22 @@ def denoising_value_valid(dnv):
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
17 additions, 3 deletions

@@ -958,14 +958,28 @@ def __call__(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:
                 self.vae.to(dtype=torch.float16)
         else:
-            image = latents
-            return StableDiffusionXLPipelineOutput(images=image)
+            return StableDiffusionXLPipelineOutput(images=latents)

         # apply watermark if available
         if self.watermark is not None:
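For context (not part of the commit): the decode-side expression latents * latents_std / scaling_factor + latents_mean is the algebraic inverse of a normalization of the form (z - latents_mean) * scaling_factor / latents_std. A quick round-trip check with made-up values (a sketch under that assumption, not library code):

import torch

scaling_factor = 0.13025  # hypothetical value, for illustration only
mean = torch.tensor([0.1, -0.2, 0.05, 0.0]).view(1, 4, 1, 1)
std = torch.tensor([1.1, 0.9, 1.0, 1.2]).view(1, 4, 1, 1)

z = torch.randn(2, 4, 32, 32)  # pretend VAE-space latents

# normalization of the form the decode branch inverts
normalized = (z - mean) * scaling_factor / std

# decode-side denormalization as added in this commit
recovered = normalized * std / scaling_factor + mean

assert torch.allclose(recovered, z, atol=1e-5)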
