Skip to content

Commit 3da5a07

Browse files
authored
SDPA backend priority (Comfy-Org#9299)
1 parent afa0a45 commit 3da5a07

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

comfy/ldm/hunyuan3d/vae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
178178

179179
class CrossAttentionProcessor:
180180
def __call__(self, attn, q, k, v):
181-
out = F.scaled_dot_product_attention(q, k, v)
181+
out = ops.scaled_dot_product_attention(q, k, v)
182182
return out
183183

184184

comfy/ldm/modules/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
448448
mask = mask.unsqueeze(1)
449449

450450
if SDP_BATCH_LIMIT >= b:
451-
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
451+
out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
452452
if not skip_output_reshape:
453453
out = (
454454
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
461461
if mask.shape[0] > 1:
462462
m = mask[i : i + SDP_BATCH_LIMIT]
463463

464-
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
464+
out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention(
465465
q[i : i + SDP_BATCH_LIMIT],
466466
k[i : i + SDP_BATCH_LIMIT],
467467
v[i : i + SDP_BATCH_LIMIT],

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
285285
)
286286

287287
try:
288-
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
288+
out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
289289
out = out.transpose(2, 3).reshape(orig_shape)
290290
except model_management.OOM_EXCEPTION:
291291
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")

comfy/ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,18 @@
2323
import comfy.float
2424
import comfy.rmsnorm
2525
import contextlib
26+
from torch.nn.attention import SDPBackend, sdpa_kernel
2627

2728
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
2829

30+
SDPA_BACKEND_PRIORITY = [
31+
SDPBackend.FLASH_ATTENTION,
32+
SDPBackend.EFFICIENT_ATTENTION,
33+
SDPBackend.MATH,
34+
]
35+
if torch.cuda.is_available():
36+
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
37+
2938
def cast_to_input(weight, input, non_blocking=False, copy=True):
3039
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
3140

@@ -249,6 +258,10 @@ def conv_nd(s, dims, *args, **kwargs):
249258
else:
250259
raise ValueError(f"unsupported dimensions: {dims}")
251260

261+
@staticmethod
262+
@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
263+
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
264+
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
252265

253266
class manual_cast(disable_weight_init):
254267
class Linear(disable_weight_init.Linear):

0 commit comments

Comments (0)