import math

from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock


class QwenImageFunControlBlock(QwenImageTransformerBlock):
    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
        super().__init__(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.has_before_proj = has_before_proj
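        # before_proj (first block only) maps the packed control tokens into the image
        # token space before they are added to the base hidden states; after_proj maps
        # each block's output into the per-layer residual injected into the main model.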
        if has_before_proj:
            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)


class QwenImageFunControlNetModel(torch.nn.Module):
    def __init__(
        self,
        control_in_features=132,
        inner_dim=3072,
        num_attention_heads=24,
        attention_head_dim=128,
        num_control_blocks=5,
        main_model_double=60,
        injection_layers=(0, 12, 24, 36, 48),
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
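        # main_model_double is the number of transformer blocks in the base model that
        # the residual list is sized for; injection_layers are the block indices that
        # actually receive a control residual (one per control block output).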
        self.main_model_double = main_model_double
        self.injection_layers = tuple(injection_layers)
        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
        # to the reference Gen2/VideoX implementation around strength=1.
        self.hint_scale = 1.0
        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)

        self.control_blocks = torch.nn.ModuleList([])
        for i in range(num_control_blocks):
            self.control_blocks.append(
                QwenImageFunControlBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    has_before_proj=(i == 0),
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
            )

    def _process_hint_tokens(self, hint):
        if hint is None:
            return None
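        # Image hints arrive as (B, C, H, W); add a singleton temporal dim so the same
        # packing path also covers (B, C, T, H, W) video-style latents.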
        if hint.ndim == 4:
            hint = hint.unsqueeze(2)

        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
        # Default behavior (no inpaint input in stock Apply ControlNet) should use
        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
        expected_c = self.control_img_in.weight.shape[1] // 4
        if hint.shape[1] == 16 and expected_c == 33:
            zeros_mask = torch.zeros_like(hint[:, :1])
            zeros_inpaint = torch.zeros_like(hint)
            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)

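        # Pack 2x2 spatial patches into the channel dim (after padding odd sizes to even)
        # so every hint token lines up with one packed image token of the base model.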
        bs, c, t, h, w = hint.shape
        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(
            orig_shape[0],
            orig_shape[1],
            orig_shape[-3],
            orig_shape[-2] // 2,
            2,
            orig_shape[-1] // 2,
            2,
        )
        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
        hidden_states = hidden_states.reshape(
            bs,
            t * ((h + 1) // 2) * ((w + 1) // 2),
            c * 4,
        )

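        # control_img_in may expect more (or fewer) packed features than the hint provides
        # (e.g. 64 for a control-only latent vs. 132 for control+mask+inpaint); zero-pad or
        # truncate the feature dim so the projection always matches the checkpoint.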
        expected_in = self.control_img_in.weight.shape[1]
        cur_in = hidden_states.shape[-1]
        if cur_in < expected_in:
            pad = torch.zeros(
                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
            hidden_states = torch.cat([hidden_states, pad], dim=-1)
        elif cur_in > expected_in:
            hidden_states = hidden_states[:, :, :expected_in]

        return hidden_states

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        hint=None,
        transformer_options={},
        base_model=None,
        **kwargs,
    ):
        if base_model is None:
            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")

        # Keep attention mask disabled inside Fun control blocks to mirror
        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
        encoder_hidden_states_mask = None

        hidden_states, img_ids, _ = base_model.process_img(x)
        hint_tokens = self._process_hint_tokens(hint)
        if hint_tokens is None:
            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")

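        # The packed token counts can differ if the hint latent and the denoised latent
        # were produced at different spatial sizes; trim both streams (and the image ids)
        # to the shared length so attention and RoPE stay aligned.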
        if hint_tokens.shape[1] != hidden_states.shape[1]:
            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
            hint_tokens = hint_tokens[:, :max_tokens]
            hidden_states = hidden_states[:, :max_tokens]
            img_ids = img_ids[:, :max_tokens]

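        # Text tokens get RoPE ids starting just past the image grid extent, mirroring
        # how the base model lays out its own text/image positional ids.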
        txt_start = round(
            max(
                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
            )
        )
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()

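        # Reuse the base model's own input projections and timestep embedder so the
        # control branch operates in the same embedding space as the main transformer.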
        hidden_states = base_model.img_in(hidden_states)
        encoder_hidden_states = base_model.txt_norm(context)
        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            base_model.time_text_embed(timesteps, hidden_states)
            if guidance is None
            else base_model.time_text_embed(timesteps, guidance, hidden_states)
        )

        c = self.control_img_in(hint_tokens)

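        # VideoX/Fun-style control stack: block 0 folds the hint into the image tokens via
        # before_proj, each later block consumes the previous running state, every block
        # emits a skip (after_proj) plus its state, and the stacked skips become the
        # per-layer residuals handed back to the main model.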
        for i, block in enumerate(self.control_blocks):
            if i == 0:
                c_in = block.before_proj(c) + hidden_states
                all_c = []
            else:
                all_c = list(torch.unbind(c, dim=0))
                c_in = all_c.pop(-1)

            encoder_hidden_states, c_out = block(
                hidden_states=c_in,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                transformer_options=transformer_options,
            )

            c_skip = block.after_proj(c_out) * self.hint_scale
            all_c += [c_skip, c_out]
            c = torch.stack(all_c, dim=0)

        hints = torch.unbind(c, dim=0)[:-1]

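        # Scatter the skip residuals into a list sized for the base model's blocks:
        # entry base_idx gets the output of control block local_idx, all other entries
        # stay None so those layers receive no control signal.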
        controlnet_block_samples = [None] * self.main_model_double
        for local_idx, base_idx in enumerate(self.injection_layers):
            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
                controlnet_block_samples[base_idx] = hints[local_idx]

        return {"input": controlnet_block_samples}


class QwenImageControlNetModel(QwenImageTransformer2DModel):