2020
2121from ...configuration_utils import ConfigMixin , register_to_config
2222from ...loaders import FromOriginalModelMixin , PeftAdapterMixin
23- from ...utils import USE_PEFT_BACKEND , BaseOutput , deprecate , logging , scale_lora_layers , unscale_lora_layers
23+ from ...utils import USE_PEFT_BACKEND , BaseOutput , logging , scale_lora_layers , unscale_lora_layers
2424from ..attention import AttentionMixin
2525from ..cache_utils import CacheMixin
2626from ..controlnets .controlnet import zero_module
3131 QwenImageTransformerBlock ,
3232 QwenTimestepProjEmbeddings ,
3333 RMSNorm ,
34- compute_text_seq_len_from_mask ,
3534)
3635
3736
@@ -137,7 +136,7 @@ def forward(
137136 return_dict : bool = True ,
138137 ) -> Union [torch .FloatTensor , Transformer2DModelOutput ]:
139138 """
140- The [`QwenImageControlNetModel `] forward method.
139+ The [`FluxTransformer2DModel `] forward method.
141140
142141 Args:
143142 hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -148,39 +147,24 @@ def forward(
148147 The scale factor for ControlNet outputs.
149148 encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
150149 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
151- encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
152- Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
153- Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
154- (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
150+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
151+ from the embeddings of input conditions.
155152 timestep ( `torch.LongTensor`):
156153 Used to indicate denoising step.
157- img_shapes (`List[Tuple[int, int, int]]`, *optional*):
158- Image shapes for RoPE computation.
159- txt_seq_lens (`List[int]`, *optional*):
160- **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
161- length.
154+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
155+ A list of tensors that if specified are added to the residuals of transformer blocks.
162156 joint_attention_kwargs (`dict`, *optional*):
163157 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
164158 `self.processor` in
165159 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
166160 return_dict (`bool`, *optional*, defaults to `True`):
167- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
161+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
162+ tuple.
168163
169164 Returns:
170- If `return_dict` is True, a [`~models.controlnet.ControlNetOutput `] is returned, otherwise a `tuple` where
171- the first element is the controlnet block samples .
165+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput `] is returned, otherwise a
166+ `tuple` where the first element is the sample tensor .
172167 """
173- # Handle deprecated txt_seq_lens parameter
174- if txt_seq_lens is not None :
175- deprecate (
176- "txt_seq_lens" ,
177- "0.39.0" ,
178- "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
179- "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
180- "and `encoder_hidden_states_mask`." ,
181- standard_warn = False ,
182- )
183-
184168 if joint_attention_kwargs is not None :
185169 joint_attention_kwargs = joint_attention_kwargs .copy ()
186170 lora_scale = joint_attention_kwargs .pop ("scale" , 1.0 )
@@ -202,47 +186,32 @@ def forward(
202186
203187 temb = self .time_text_embed (timestep , hidden_states )
204188
205- # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
206- text_seq_len , _ , encoder_hidden_states_mask = compute_text_seq_len_from_mask (
207- encoder_hidden_states , encoder_hidden_states_mask
208- )
209-
210- image_rotary_emb = self .pos_embed (img_shapes , max_txt_seq_len = text_seq_len , device = hidden_states .device )
189+ image_rotary_emb = self .pos_embed (img_shapes , txt_seq_lens , device = hidden_states .device )
211190
212191 timestep = timestep .to (hidden_states .dtype )
213192 encoder_hidden_states = self .txt_norm (encoder_hidden_states )
214193 encoder_hidden_states = self .txt_in (encoder_hidden_states )
215194
216- # Construct joint attention mask once to avoid reconstructing in every block
217- block_attention_kwargs = joint_attention_kwargs .copy () if joint_attention_kwargs is not None else {}
218- if encoder_hidden_states_mask is not None :
219- # Build joint mask: [text_mask, all_ones_for_image]
220- batch_size , image_seq_len = hidden_states .shape [:2 ]
221- image_mask = torch .ones ((batch_size , image_seq_len ), dtype = torch .bool , device = hidden_states .device )
222- joint_attention_mask = torch .cat ([encoder_hidden_states_mask , image_mask ], dim = 1 )
223- block_attention_kwargs ["attention_mask" ] = joint_attention_mask
224-
225195 block_samples = ()
226- for block in self .transformer_blocks :
196+ for index_block , block in enumerate ( self .transformer_blocks ) :
227197 if torch .is_grad_enabled () and self .gradient_checkpointing :
228198 encoder_hidden_states , hidden_states = self ._gradient_checkpointing_func (
229199 block ,
230200 hidden_states ,
231201 encoder_hidden_states ,
232- None , # Don't pass encoder_hidden_states_mask (using attention_mask instead)
202+ encoder_hidden_states_mask ,
233203 temb ,
234204 image_rotary_emb ,
235- block_attention_kwargs ,
236205 )
237206
238207 else :
239208 encoder_hidden_states , hidden_states = block (
240209 hidden_states = hidden_states ,
241210 encoder_hidden_states = encoder_hidden_states ,
242- encoder_hidden_states_mask = None , # Don't pass (using attention_mask instead)
211+ encoder_hidden_states_mask = encoder_hidden_states_mask ,
243212 temb = temb ,
244213 image_rotary_emb = image_rotary_emb ,
245- joint_attention_kwargs = block_attention_kwargs ,
214+ joint_attention_kwargs = joint_attention_kwargs ,
246215 )
247216 block_samples = block_samples + (hidden_states ,)
248217
@@ -298,15 +267,6 @@ def forward(
298267 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
299268 return_dict : bool = True ,
300269 ) -> Union [QwenImageControlNetOutput , Tuple ]:
301- if txt_seq_lens is not None :
302- deprecate (
303- "txt_seq_lens" ,
304- "0.39.0" ,
305- "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
306- "removed in version 0.39.0. The text sequence length is now automatically inferred from "
307- "`encoder_hidden_states` and `encoder_hidden_states_mask`." ,
308- standard_warn = False ,
309- )
310270 # ControlNet-Union with multiple conditions
311271 # only load one ControlNet for saving memories
312272 if len (self .nets ) == 1 :
@@ -321,6 +281,7 @@ def forward(
321281 encoder_hidden_states_mask = encoder_hidden_states_mask ,
322282 timestep = timestep ,
323283 img_shapes = img_shapes ,
284+ txt_seq_lens = txt_seq_lens ,
324285 joint_attention_kwargs = joint_attention_kwargs ,
325286 return_dict = return_dict ,
326287 )
0 commit comments