Skip to content

Commit 6b77b72

Browse files
committed
Revert " Fix QwenImage txt_seq_lens handling (huggingface#12702)"
This reverts commit dad5cb5.
1 parent 8daaa6b commit 6b77b72

17 files changed

+172
-513
lines changed

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -108,46 +108,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
108108
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
109109
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
110110
image = pipe(
111-
image=[image_1, image_2],
112-
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
111+
image=[image_1, image_2],
112+
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
113113
num_inference_steps=50
114114
).images[0]
115115
```
116116

117-
## Performance
118-
119-
### torch.compile
120-
121-
Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
122-
123-
```python
124-
import torch
125-
from diffusers import QwenImagePipeline
126-
127-
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
128-
pipe.transformer = torch.compile(pipe.transformer)
129-
130-
# First call triggers compilation (~7s overhead)
131-
# Subsequent calls run at ~2.4x faster
132-
image = pipe("a cat", num_inference_steps=50).images[0]
133-
```
134-
135-
### Batched Inference with Variable-Length Prompts
136-
137-
When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
138-
139-
```python
140-
# CFG with different prompt lengths works correctly
141-
image = pipe(
142-
prompt="A cat",
143-
negative_prompt="blurry, low quality, distorted",
144-
true_cfg_scale=3.5,
145-
num_inference_steps=50,
146-
).images[0]
147-
```
148-
149-
For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
150-
151117
## QwenImagePipeline
152118

153119
[[autodoc]] QwenImagePipeline

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,12 +1513,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15131513
height=model_input.shape[3],
15141514
width=model_input.shape[4],
15151515
)
1516+
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
15161517
model_pred = transformer(
15171518
hidden_states=packed_noisy_model_input,
15181519
encoder_hidden_states=prompt_embeds,
15191520
encoder_hidden_states_mask=prompt_embeds_mask,
15201521
timestep=timesteps / 1000,
15211522
img_shapes=img_shapes,
1523+
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
15221524
return_dict=False,
15231525
)[0]
15241526
model_pred = QwenImagePipeline._unpack_latents(

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,43 +2128,6 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
21282128
return out
21292129

21302130

2131-
def _prepare_additive_attn_mask(
2132-
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
2133-
) -> torch.Tensor:
2134-
"""
2135-
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
2136-
2137-
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
2138-
2139-
Args:
2140-
attn_mask: 2D tensor [batch_size, seq_len_k]
2141-
- Boolean: True means attend, False means mask out
2142-
- Additive: 0.0 means attend, -inf means mask out
2143-
target_dtype: The dtype to convert the mask to (usually query.dtype)
2144-
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
2145-
2146-
Returns:
2147-
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
2148-
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
2149-
"""
2150-
# Check if the mask is boolean or already additive
2151-
if attn_mask.dtype == torch.bool:
2152-
# Convert boolean to additive: True -> 0.0, False -> -inf
2153-
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
2154-
# Convert to target dtype
2155-
attn_mask = attn_mask.to(dtype=target_dtype)
2156-
else:
2157-
# Already additive mask - just ensure correct dtype
2158-
attn_mask = attn_mask.to(dtype=target_dtype)
2159-
2160-
# Optionally reshape to 4D for broadcasting in attention mechanisms
2161-
if reshape_4d:
2162-
batch_size, seq_len_k = attn_mask.shape
2163-
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
2164-
2165-
return attn_mask
2166-
2167-
21682131
@_AttentionBackendRegistry.register(
21692132
AttentionBackendName.NATIVE,
21702133
constraints=[_check_device, _check_shape],
@@ -2184,19 +2147,6 @@ def _native_attention(
21842147
) -> torch.Tensor:
21852148
if return_lse:
21862149
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
2187-
2188-
# Reshape 2D mask to 4D for SDPA
2189-
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
2190-
if (
2191-
attn_mask is not None
2192-
and attn_mask.ndim == 2
2193-
and attn_mask.shape[0] == query.shape[0]
2194-
and attn_mask.shape[1] == key.shape[1]
2195-
):
2196-
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
2197-
# SDPA handles both boolean and additive masks correctly
2198-
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
2199-
22002150
if _parallel_config is None:
22012151
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
22022152
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2763,34 +2713,10 @@ def _xformers_attention(
27632713
attn_mask = xops.LowerTriangularMask()
27642714
elif attn_mask is not None:
27652715
if attn_mask.ndim == 2:
2766-
# Convert 2D mask to 4D for xformers
2767-
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
2768-
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
2769-
# Need memory alignment - create larger tensor and slice for alignment
2770-
original_seq_len = attn_mask.size(1)
2771-
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
2772-
2773-
# Create aligned 4D tensor and slice to ensure proper memory layout
2774-
aligned_mask = torch.zeros(
2775-
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
2776-
dtype=query.dtype,
2777-
device=query.device,
2778-
)
2779-
# Convert to 4D additive mask (handles both boolean and additive inputs)
2780-
mask_additive = _prepare_additive_attn_mask(
2781-
attn_mask, target_dtype=query.dtype
2782-
) # [batch, 1, 1, seq_len_k]
2783-
# Broadcast to [batch, heads, seq_q, seq_len_k]
2784-
aligned_mask[:, :, :, :original_seq_len] = mask_additive
2785-
# Mask out the padding (already -inf from zeros -> where with default)
2786-
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
2787-
2788-
# Slice to actual size with proper alignment
2789-
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
2716+
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
27902717
elif attn_mask.ndim != 4:
27912718
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
2792-
elif attn_mask.ndim == 4:
2793-
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
2719+
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
27942720

27952721
if enable_gqa:
27962722
if num_heads_q % num_heads_kv != 0:

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 16 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
2424
from ..attention import AttentionMixin
2525
from ..cache_utils import CacheMixin
2626
from ..controlnets.controlnet import zero_module
@@ -31,7 +31,6 @@
3131
QwenImageTransformerBlock,
3232
QwenTimestepProjEmbeddings,
3333
RMSNorm,
34-
compute_text_seq_len_from_mask,
3534
)
3635

3736

@@ -137,7 +136,7 @@ def forward(
137136
return_dict: bool = True,
138137
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
139138
"""
140-
The [`QwenImageControlNetModel`] forward method.
139+
The [`FluxTransformer2DModel`] forward method.
141140
142141
Args:
143142
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -148,39 +147,24 @@ def forward(
148147
The scale factor for ControlNet outputs.
149148
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
150149
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
151-
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
152-
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
153-
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
154-
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
150+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
151+
from the embeddings of input conditions.
155152
timestep ( `torch.LongTensor`):
156153
Used to indicate denoising step.
157-
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
158-
Image shapes for RoPE computation.
159-
txt_seq_lens (`List[int]`, *optional*):
160-
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
161-
length.
154+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
155+
A list of tensors that if specified are added to the residuals of transformer blocks.
162156
joint_attention_kwargs (`dict`, *optional*):
163157
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
164158
`self.processor` in
165159
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
166160
return_dict (`bool`, *optional*, defaults to `True`):
167-
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
161+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
162+
tuple.
168163
169164
Returns:
170-
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
171-
the first element is the controlnet block samples.
165+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
166+
`tuple` where the first element is the sample tensor.
172167
"""
173-
# Handle deprecated txt_seq_lens parameter
174-
if txt_seq_lens is not None:
175-
deprecate(
176-
"txt_seq_lens",
177-
"0.39.0",
178-
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
179-
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
180-
"and `encoder_hidden_states_mask`.",
181-
standard_warn=False,
182-
)
183-
184168
if joint_attention_kwargs is not None:
185169
joint_attention_kwargs = joint_attention_kwargs.copy()
186170
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -202,47 +186,32 @@ def forward(
202186

203187
temb = self.time_text_embed(timestep, hidden_states)
204188

205-
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
206-
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
207-
encoder_hidden_states, encoder_hidden_states_mask
208-
)
209-
210-
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
189+
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
211190

212191
timestep = timestep.to(hidden_states.dtype)
213192
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
214193
encoder_hidden_states = self.txt_in(encoder_hidden_states)
215194

216-
# Construct joint attention mask once to avoid reconstructing in every block
217-
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
218-
if encoder_hidden_states_mask is not None:
219-
# Build joint mask: [text_mask, all_ones_for_image]
220-
batch_size, image_seq_len = hidden_states.shape[:2]
221-
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
222-
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
223-
block_attention_kwargs["attention_mask"] = joint_attention_mask
224-
225195
block_samples = ()
226-
for block in self.transformer_blocks:
196+
for index_block, block in enumerate(self.transformer_blocks):
227197
if torch.is_grad_enabled() and self.gradient_checkpointing:
228198
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
229199
block,
230200
hidden_states,
231201
encoder_hidden_states,
232-
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
202+
encoder_hidden_states_mask,
233203
temb,
234204
image_rotary_emb,
235-
block_attention_kwargs,
236205
)
237206

238207
else:
239208
encoder_hidden_states, hidden_states = block(
240209
hidden_states=hidden_states,
241210
encoder_hidden_states=encoder_hidden_states,
242-
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
211+
encoder_hidden_states_mask=encoder_hidden_states_mask,
243212
temb=temb,
244213
image_rotary_emb=image_rotary_emb,
245-
joint_attention_kwargs=block_attention_kwargs,
214+
joint_attention_kwargs=joint_attention_kwargs,
246215
)
247216
block_samples = block_samples + (hidden_states,)
248217

@@ -298,15 +267,6 @@ def forward(
298267
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
299268
return_dict: bool = True,
300269
) -> Union[QwenImageControlNetOutput, Tuple]:
301-
if txt_seq_lens is not None:
302-
deprecate(
303-
"txt_seq_lens",
304-
"0.39.0",
305-
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
306-
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
307-
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
308-
standard_warn=False,
309-
)
310270
# ControlNet-Union with multiple conditions
311271
# only load one ControlNet for saving memories
312272
if len(self.nets) == 1:
@@ -321,6 +281,7 @@ def forward(
321281
encoder_hidden_states_mask=encoder_hidden_states_mask,
322282
timestep=timestep,
323283
img_shapes=img_shapes,
284+
txt_seq_lens=txt_seq_lens,
324285
joint_attention_kwargs=joint_attention_kwargs,
325286
return_dict=return_dict,
326287
)

0 commit comments

Comments
 (0)