14 changes: 9 additions & 5 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -134,7 +134,8 @@ def apply_rotary_emb(
 dropout_p=0.0,
 is_causal=False,
 backend=self._attention_backend,
-parallel_config=self._parallel_config,
+# Reference: https://github.com/huggingface/diffusers/pull/12909
+parallel_config=None,
 )
 hidden_states_img = hidden_states_img.flatten(2, 3)
 hidden_states_img = hidden_states_img.type_as(query)
@@ -147,7 +148,8 @@ def apply_rotary_emb(
 dropout_p=0.0,
 is_causal=False,
 backend=self._attention_backend,
-parallel_config=self._parallel_config,
+# Reference: https://github.com/huggingface/diffusers/pull/12909
+parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
 )
 hidden_states = hidden_states.flatten(2, 3)
 hidden_states = hidden_states.type_as(query)
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Reference: https://github.com/huggingface/diffusers/pull/12909
@sayakpaul (Member) commented on Jan 5, 2026:

Is this specific to I2V only? If so, then this change is probably a little too intrusive, no?

@DefTruth (Contributor, Author) commented on Jan 5, 2026:

@sayakpaul This is theoretically applicable to all Wan-series models and offers better performance. I tested the Wan 2.1/2.2 T2V/I2V models and the results were all correct. You can verify it quickly with the examples in cache-dit; this patch is already used there to fix some precision issues (vipshop/cache-dit#639).

pip3 install torch==2.9.1 transformers accelerate torchao bitsandbytes torchvision 
pip3 install opencv-python-headless einops imageio-ffmpeg ftfy 
pip3 install git+https://github.com/huggingface/diffusers.git # latest 
pip3 install git+https://github.com/vipshop/cache-dit.git # latest

git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples

# Use `--cpu-offload` and `--parallel-text-encoder` on low-VRAM devices, e.g., < 48 GiB
torchrun --nproc_per_node=4 generate.py wan2.1_t2v --parallel ulysses --parallel-text-encoder
torchrun --nproc_per_node=4 generate.py wan2.2_t2v --parallel ulysses --parallel-text-encoder --cpu-offload
torchrun --nproc_per_node=2 generate.py wan2.1_i2v --parallel ulysses --parallel-text-encoder --steps 16 --frames 21 --vae-tiling
torchrun --nproc_per_node=2 generate.py wan2.2_i2v --parallel ulysses --parallel-text-encoder --cpu-offload --steps 16 --frames 21 --vae-tiling

+# We need to disable the splitting of encoder_hidden_states because the image_encoder
+# (Wan 2.1 I2V) always produces 257 tokens for image_embed, so the token count of
+# encoder_hidden_states after concatenation is always 769 (512 + 257), which is not
+# divisible by the number of devices in the context-parallel group.
 "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
 "": {
     "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
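A minimal sketch, assuming the token counts quoted in the new comment above (512 text tokens plus 257 image_embed tokens for Wan 2.1 I2V), of why the concatenated encoder_hidden_states can never be split evenly across common context-parallel world sizes:

# Illustrative check of the divisibility problem described in the comment above.
text_tokens, image_tokens = 512, 257   # Wan 2.1 I2V: fixed text length + image_embed length
seq_len = text_tokens + image_tokens   # 769
for world_size in (2, 4, 8):
    print(world_size, seq_len % world_size)   # remainder 1 for every common CP size
# Because 769 is never evenly divisible, the CP plan keeps encoder_hidden_states
# replicated on every rank instead of sharding it.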
6 changes: 4 additions & 2 deletions src/diffusers/models/transformers/transformer_wan_animate.py
@@ -609,7 +609,8 @@ def apply_rotary_emb(
 dropout_p=0.0,
 is_causal=False,
 backend=self._attention_backend,
-parallel_config=self._parallel_config,
+# Reference: https://github.com/huggingface/diffusers/pull/12909
+parallel_config=None,
 )
 hidden_states_img = hidden_states_img.flatten(2, 3)
 hidden_states_img = hidden_states_img.type_as(query)
@@ -622,7 +623,8 @@ def apply_rotary_emb(
 dropout_p=0.0,
 is_causal=False,
 backend=self._attention_backend,
-parallel_config=self._parallel_config,
+# Reference: https://github.com/huggingface/diffusers/pull/12909
+parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
 )
 hidden_states = hidden_states.flatten(2, 3)
 hidden_states = hidden_states.type_as(query)
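Both files apply the same guarded pattern: context parallelism is kept for self-attention (encoder_hidden_states is None) and bypassed for cross-attention, whose key/value tokens are replicated on every rank. A minimal sketch of that guard, with attn_fn standing in for the attention dispatch call and all names purely illustrative:

# Sketch of the guarded parallel_config pattern from this PR; `attn_fn` is a
# stand-in for the attention dispatch call, not the actual diffusers API.
def run_attention(attn_fn, query, key, value, parallel_config=None, encoder_hidden_states=None):
    # Self-attention: hidden_states are sharded across ranks, so CP applies.
    # Cross-attention: text/image K/V are replicated, so CP must be skipped.
    cfg = parallel_config if encoder_hidden_states is None else None
    return attn_fn(query, key, value, dropout_p=0.0, is_causal=False, parallel_config=cfg)

With the guard in place, each rank attends to the full replicated context locally, avoiding the shape mismatch that splitting the 769-token sequence would cause.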