Support Multi/InfiniteTalk #10179
base: master
Changes from all commits: efe83f5, 460ce7f, 6f6db12, 00c069d, 57567bd, 9c5022e, d0dce6b, 7842a5c, 99dc959, 4cbc1a6, f5d53f2
```diff
@@ -87,7 +87,7 @@ def qkv_fn_k(x):
         )

         x = self.o(x)
-        return x
+        return x, q, k


 class WanT2VCrossAttention(WanSelfAttention):
```
```diff
@@ -178,7 +178,8 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6, operation_settings={},
+                 block_idx=None):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
```
```diff
@@ -187,6 +188,7 @@ def __init__(self,
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.block_idx = block_idx
```
Review comment on this line (truncated): "Instead of having the …"
```diff

         # layers
         self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
```
```diff
@@ -225,14 +227,16 @@ def forward(
         """
         # assert e.dtype == torch.float32

+        patches = transformer_options.get("patches", {})
+
         if e.ndim < 4:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
         # assert e[0].dtype == torch.float32

         # self-attention
-        y = self.self_attn(
+        y, q, k = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)
```
```diff
@@ -241,6 +245,11 @@ def forward(

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+        if "cross_attn" in patches:
+            for p in patches["cross_attn"]:
+                x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
+
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
```
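For readers following the new `patches["cross_attn"]` hook, here is a minimal sketch (not part of this PR) of how a custom node might register such a patch. It assumes the existing `ModelPatcher.set_model_patch(patch, name)` helper, which appends a callable under `transformer_options["patches"][name]`; the argument dict mirrors the loop above, while `make_talk_patch` and the zero residual are hypothetical placeholders for whatever Multi/InfiniteTalk actually computes.

```python
import torch

def make_talk_patch(audio_features=None):
    # Hypothetical factory; audio_features stands in for whatever conditioning
    # the Multi/InfiniteTalk nodes would supply.
    def cross_attn_patch(kwargs):
        x = kwargs["x"]                  # (B, L, C) hidden states after WAN's cross-attention
        q, k = kwargs["q"], kwargs["k"]  # q/k returned by this block's self-attention
        block_idx = kwargs["block_idx"]  # which WanAttentionBlock is calling the patch
        t_opts = kwargs["transformer_options"]
        grid_sizes = t_opts.get("grid_sizes")  # (t, h, w) patch grid stored by forward_orig

        # ... compute an audio-conditioned residual from x/q/k here ...
        return torch.zeros_like(x)       # placeholder: the returned tensor is added to x
    return cross_attn_patch

# Registration on a loaded WAN ModelPatcher, e.g. inside a custom node:
#   model = model.clone()
#   model.set_model_patch(make_talk_patch(audio_features), "cross_attn")
```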
```diff
@@ -262,6 +271,7 @@ def __init__(
     ):
         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
         self.block_id = block_id
+        self.block_idx = None
         if block_id == 0:
             self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
             self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
```
```diff
@@ -486,8 +496,8 @@ def __init__(self,
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
-                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
-            for _ in range(num_layers)
+                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
+            for i in range(num_layers)
         ])

         # head
```
```diff
@@ -540,6 +550,7 @@ def forward_orig(
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
```
```diff
@@ -722,6 +733,7 @@ def forward_orig(
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
```
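The two `grid_sizes` additions expose the latent patch grid to patches before `x` is flattened to tokens. A small sketch of how a patch could use it to recover the (t, h, w) layout; the helper below is illustrative, not something this diff adds:

```python
import torch

def tokens_to_grid(x: torch.Tensor, transformer_options: dict) -> torch.Tensor:
    # x: (B, L, C) flattened tokens inside a WAN block.
    t, h, w = transformer_options["grid_sizes"]  # stored by forward_orig in this PR
    b, l, c = x.shape
    assert l == t * h * w, "token count must match the stored patch grid"
    return x.reshape(b, t, h, w, c)  # back to (B, T, H, W, C) for spatio-temporal ops
```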
Review comment:
There is some uncertainty about whether returning this will, in general, increase the memory peak of WAN within native ComfyUI. Instead, comfy suggests adding a patch that replaces the `x = optimized_attention(...)` call on line 81 by reusing the `ModelPatcher.set_model_attn1_replace` functionality (in unet, attn1 is self-attention and attn2 is cross-attention). That replacement can then do the `optimized_attention` call plus the partial attention that currently happens inside the `cross_attn` patch. To get the q and k into the `cross_attn` patch, you can store them in `transformer_options` instead and pop them out after use.

The `transformer_index` can stay None (not given), since that was something unique to unet models.

It would probably be more optimal to stop calling `optimized_attention` entirely and just reuse the logic of the slower partial attention in this code, but comfy said he would be fine if you didn't go that far and kept both inside that attention replacement function.
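A rough sketch of the alternative described above, with the assumptions spelled out: it presumes WAN's self-attention gets wired to honour `set_model_attn1_replace` (something this PR would still have to add), that the replacement keeps the unet-style `(q, k, v, extra_options)` signature, and that `extra_options` exposes `transformer_options` and the head count; the `("double_block", i)` addressing mirrors the existing WAN block-replace keys and is likewise an assumption.

```python
from comfy.ldm.modules.attention import optimized_attention

def make_attn1_replace():
    # Replaces WAN self-attention's optimized_attention call while stashing q/k in
    # transformer_options, so a later "cross_attn" patch can reuse them instead of
    # the block returning them (which may raise the memory peak).
    def attn1_replace(q, k, v, extra_options):
        t_opts = extra_options.get("transformer_options", extra_options)  # assumption: passed through
        t_opts["wan_self_attn_qk"] = (q, k)  # the cross_attn patch should pop() this after use
        return optimized_attention(q, k, v, extra_options["n_heads"])
    return attn1_replace

# Hypothetical registration on the ModelPatcher; transformer_index stays None as noted above:
#   for i in range(num_layers):
#       model.set_model_attn1_replace(make_attn1_replace(), "double_block", i)
```

The consuming `cross_attn` patch would `pop` `wan_self_attn_qk` from `transformer_options` right after use, so the q/k tensors are not kept alive across blocks.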