Skip to content

Commit 803e817

Browse files
iamzoltansayakpaul
andauthored
Add vae slicing and tiling to flux pipeline (#9122)
add vae slicing and tiling to flux pipeline Co-authored-by: Sayak Paul <[email protected]>
1 parent 67f5cce commit 803e817

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,35 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
454454

455455
return latents
456456

457+
def enable_vae_slicing(self):
458+
r"""
459+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
460+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
461+
"""
462+
self.vae.enable_slicing()
463+
464+
def disable_vae_slicing(self):
465+
r"""
466+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
467+
computing decoding in one step.
468+
"""
469+
self.vae.disable_slicing()
470+
471+
def enable_vae_tiling(self):
472+
r"""
473+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
474+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
475+
processing larger images.
476+
"""
477+
self.vae.enable_tiling()
478+
479+
def disable_vae_tiling(self):
480+
r"""
481+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
482+
computing decoding in one step.
483+
"""
484+
self.vae.disable_tiling()
485+
457486
def prepare_latents(
458487
self,
459488
batch_size,

0 commit comments

Comments
 (0)