
[Quantization] Add Quanto backend #10756

Merged: 41 commits into main on Mar 10, 2025

Conversation


@DN6 DN6 commented Feb 10, 2025

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 requested review from a-r-r-o-w and sayakpaul and removed request for a-r-r-o-w February 10, 2025 07:26
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
Member
Have we verified this? Last time I checked only weight-quantized models were compatible with torch.compile. Cc: @dacorvo.

@dacorvo

True, but this should be fixed in pytorch 2.6 (I did not check though).

Collaborator Author

@dacorvo I tried to run torch compile with float8 weights in the following way and hit an error during inference

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qfloat8

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")

Traceback:

  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2082, in validate
    raise AssertionError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16
), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072
), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16,
           requires_grad=True))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., de
vice='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3
072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/optimum/quanto/nn/qlinear.py", line 50, in forward
    return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

The torch.compile step seems to work. The error is raised during the forward pass.
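For comparison, a minimal sketch of the int8-weight variant, which the rest of this thread suggests does compile (same environment assumed; not re-verified here):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import freeze, qint8, quantize

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qint8)  # int8 weights instead of float8
freeze(transformer)

transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
```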

Member

Same with nightly?

Collaborator Author

Yeah same errors with nightly.

Member

Let's be specific that only int8 supports torch.compile for now?

Collaborator Author

Mentioned in the compile section of the docs

from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
Member

Perhaps a comment to note that only weights will be quantized.
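For example, the snippet could carry a short note along these lines (a sketch that keeps the `weights=` argument used in this revision; the merged API may name it differently):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
# Note: only the weights are quantized to float8; activations keep torch_dtype.
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```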

Comment on lines 56 to 58
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member

Oh lovely. Not to digress from this PR, but would it make sense to also do something similar for bitsandbytes and torchao with `from_single_file()`, or not yet?

Collaborator Author

TorchAO should just work out of the box. We can add a section to the docs page.

For BnB, the conversion step in single-file loading is still a bottleneck. We need to figure out how to handle that gracefully.
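A rough sketch of what the TorchAO path might look like with `from_single_file()`, assuming `TorchAoConfig` is accepted there the same way as in `from_pretrained()` (not verified in this PR):

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = TorchAoConfig("int8wo")  # int8 weight-only quantization
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```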

- int4
- int2

### Activations
Member

Let's show an example from the docs as well?

Additionally, we could refer the users to this blog post so that they have a sense of the savings around memory and latency?
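For reference, a hedged sketch of what an activations example involves when using optimum-quanto directly; activation quantization needs calibration passes, which is why this path was later dropped from the diffusers config in this PR:

```python
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import Calibration, freeze, qint8, quantize

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
# Quantize both weights and activations; activation scales need calibration.
quantize(transformer, weights=qint8, activations=qint8)
with Calibration():
    # run a few representative forward passes here to record activation ranges
    ...
freeze(transformer)
```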

@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model,
state_dict,
device=param_device,
dtype=torch_dtype,
Member

Why is this going away?

Collaborator Author

Oh this is a mistake. Thanks for catching.


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
# Quanto imports diffusers internally. These are placed here to avoid circular imports
from optimum.quanto import QLinear, qfloat8, qint2, qint4, qint8
Member

Quanto does support Conv layers, though. Should we consider them in this PR?

Collaborator Author

Is there a model that benefits from quantized Conv layers? I recall that it didn't work so great for SD UNets?

Member

Perhaps a note could be nice if we're not adding QConv. In my experiments, I was able to quantize the VAE and obtain decent results.
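For context, a minimal sketch of that kind of experiment: quantizing a Conv-heavy VAE directly with optimum-quanto rather than through `QuantoConfig` (the model choice below is illustrative):

```python
import torch
from diffusers import AutoencoderKL
from optimum.quanto import freeze, qint8, quantize

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)
# optimum-quanto replaces both Linear and Conv2d modules with quantized versions.
quantize(vae, weights=qint8)
freeze(vae)
```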

@yiyixuxu yiyixuxu added the roadmap Add to current release roadmap label Feb 20, 2025

DN6 commented Feb 25, 2025

@dacorvo I think there might be an issue with qint4 and torch 2.6. I tried running this snippet

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qfloat8

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qint4)
freeze(transformer)

And hit this error:

Traceback (most recent call last):
  File "/home/dhruv/diffusers/../scripts/test_quanto_flux_compile.py", line 12, in <module>
    freeze(transformer)
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/quantize.py", line 146, in freeze
    m.freeze()
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 301, in freeze
    qweight = self.qweight
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 267, in qweight
    return quantize_weight(
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/quantization.py", line 73, in quantize_weight
    return WeightQBitsTensor.quantize(t, qtype, axis, group_size, scale, shift, optimized)
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbits.py", line 147, in quantize
    return WeightsQBitsQuantizer.apply(base, qtype, axis, group_size, scale, shift, optimized)
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbits.py", line 55, in forward
    return WeightQBitsTensor.create(qtype, axis, group_size, size, stride, data, scale, shift)
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbits.py", line 117, in create
    return TinyGemmWeightQBitsTensor(
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/tinygemm/qbits.py", line 84, in __init__
    self._data = TinyGemmPackedTensor.pack(ungrouped)
  File "/home/dhruv/miniconda3/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/tinygemm/packed.py", line 59, in pack
    data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
NotImplementedError: Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this
backend, or was omitted during the selective/custom build process (if using custom build)

The issue isn't there with torch 2.5.1; I assume because qint4 uses PackedTensor for the weight for torch<2.6?
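A possible workaround (not verified in this thread) is to run the qint4 quantization on the GPU, so the tinygemm packing op dispatches to a backend that implements it; a minimal sketch:

```python
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qint4, quantize

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
quantize(transformer, weights=qint4)
freeze(transformer)  # the int4 packing op now runs on the CUDA backend
```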

@DN6 DN6 requested a review from sayakpaul March 7, 2025 03:22

DN6 commented Mar 7, 2025

@sayakpaul can you take another look here? I had to remove the activation quantization option since there isn't really a good way to run the calibration step with the Diffusers API.

So for now we will only offer weights-only quantization. We can revisit if there is demand for activation quantization.

@sayakpaul sayakpaul left a comment

Thanks. Looks very close to merge.


- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
Member

Let's be specific that only int8 supports torch.compile for now?


model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member

subfolder missing.
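That is, the corrected call would read (sketch):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",  # the missing argument
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```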


model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member

subfolder missing.

quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
Member

Great that this works.

Comment on lines +249 to +250
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
pass
Member

How would the param be handled in that case?

Collaborator Author

We wouldn't apply any casting and the parameter would be loaded as-is into the model.
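To illustrate the behaviour being described (a hypothetical sketch, not the actual loading code):

```python
import torch

# Hypothetical illustration: a checkpoint param that is already float8_e4m3fn
# is left untouched instead of being cast to the model's torch_dtype, and the
# quantizer consumes it as-is.
param = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
if param.dtype == getattr(torch, "float8_e4m3fn", None):
    loaded_param = param  # no .to(torch_dtype) cast here
print(loaded_param.dtype)  # torch.float8_e4m3fn
```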

import torch
import torch.nn as nn

class LoRALayer(nn.Module):
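The quoted snippet is truncated at the review anchor; a hypothetical completion of a minimal LoRA adapter of that shape could look like this (names are illustrative, not the exact doc text):

```python
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    """Minimal LoRA adapter around a frozen (possibly quantized) base layer."""

    def __init__(self, base_layer: nn.Module, in_features: int, out_features: int, rank: int = 8):
        super().__init__()
        self.base_layer = base_layer  # e.g. a frozen quantized linear layer
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) / rank)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the trainable low-rank update.
        return self.base_layer(x) + x @ self.lora_A.T @ self.lora_B.T
```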
@sayakpaul sayakpaul left a comment

Looks good to me apart from the nits I mentioned on the docs. Thanks, Dhruv!

@DN6 DN6 merged commit f5edaa7 into main Mar 10, 2025
29 of 30 checks passed
@vladmandic
Contributor

One issue: diffusers.QuantoConfig is different from transformers.utils.quantization_config.QuantoConfig, so we cannot use the same config to load both the transformer and the text encoder (since the text encoder is loaded using transformers.T5EncoderModel.from_pretrained).

I don't see any particular reason why the configs are different; the only key thing missing is the activations property. Even if it's not implemented in diffusers, it would be good to have the configs compatible between diffusers and transformers.
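Until the two configs converge, a hedged sketch of the current workaround is to keep one config per library (argument names assumed from each library's current API):

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers import QuantoConfig as DiffusersQuantoConfig
from transformers import T5EncoderModel
from transformers import QuantoConfig as TransformersQuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# diffusers-side config for the transformer (weights_dtype in the merged API).
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=DiffusersQuantoConfig(weights_dtype="int8"),
    torch_dtype=torch.bfloat16,
)

# transformers-side config for the T5 text encoder (weights in that API).
text_encoder_2 = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=TransformersQuantoConfig(weights="int8"),
    torch_dtype=torch.bfloat16,
)
```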

sayakpaul added a commit that referenced this pull request Mar 20, 2025