
Commit 68180d0

AmdSampsa authored and jithunnair-amd committed
[ROCm] Update meta_registration for efficient attention (pytorch#146979)
Fixes a series of failing and skipped unit tests. On NVIDIA hardware the logsumexp last dimension is required to be a multiple of 32; this is not the case on ROCm. Related issue: pytorch#146848.

The unit tests in question:

```bash
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_13_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_11_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_17_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_1_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_1_freezing
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_2_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_3_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_4_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_6_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_13_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_11_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_17_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_1_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_1_freezing
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_2_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_3_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_4_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_6_cuda
```

Pull Request resolved: pytorch#146979
Approved by: https://github.com/shunting314
1 parent a36323a commit 68180d0
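
For illustration only (not part of the diff): the shape rule the updated meta registration encodes can be sketched as a small helper, where `is_hip` is a hypothetical flag standing in for the `torch.version.hip` build check.

```python
import math

def logsumexp_last_dim(seq_len: int, compute_log_sumexp: bool, is_hip: bool) -> int:
    # Sketch of the branch added to meta__scaled_dot_product_efficient_attention:
    # CUDA pads the logsumexp last dimension up to a multiple of 32, while the
    # ROCm efficient-attention kernel reports exactly the sequence length.
    if not compute_log_sumexp:
        return 0
    return seq_len if is_hip else math.ceil(seq_len / 32) * 32

assert logsumexp_last_dim(37, True, is_hip=False) == 64  # padded to next multiple of 32
assert logsumexp_last_dim(37, True, is_hip=True) == 37   # exactly the sequence length
```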

File tree

4 files changed: +10 -29 lines changed


test/inductor/test_fused_attention.py (-13 lines)

```diff
@@ -105,7 +105,6 @@ def _check_common(
                     ):
                         self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)
 
-    @skipIfRocm
     def _test_sdpa_rewriter_1(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -132,7 +131,6 @@ def dot_prod_attention(
                 rtol=rtol,
             )
 
-    @skipIfRocm
     @torch._inductor.config.patch("freezing", True)
     def _test_sdpa_rewriter_1_freezing(self):
         def dot_prod_attention(
@@ -264,7 +262,6 @@ def dot_prod_attention(
         _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
         self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)
 
-    @skipIfRocm
     def _test_sdpa_rewriter_2(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -279,7 +276,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention)
         self._check_common(checkpoint_wrapper(dot_prod_attention))
 
-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_3(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
@@ -296,7 +292,6 @@ def dot_prod_attention(
             checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
         )
 
-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_4(self):
         def dot_prod_attention(
             query: torch.Tensor,
@@ -346,7 +341,6 @@ def sfdp_pattern_5_v2(query, key, value):
         self._check_common(sfdp_pattern_5_v2, contains=False)
         self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)
 
-    @skipIfRocm
     def _test_sdpa_rewriter_6(self):
         def sfdp_pattern_6(query, key, value, training):
             attn_mask = torch.ones(
@@ -570,7 +564,6 @@ def forward(self, query, key, value, attn_mask) -> torch.Tensor:
             model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
         )
 
-    @skipIfRocm
     def _test_sdpa_rewriter_11(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -611,7 +604,6 @@ def dot_prod_attention(
 
         self._check_common(dot_prod_attention, contains=False, has_dropout=True)
 
-    @skipIfRocm
     def _test_sdpa_prev_13(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -628,7 +620,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)
 
-    @skipIfRocm
     def _test_sdpa_prev_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -644,7 +635,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)
 
-    @skipIfRocm
     def _test_sdpa_prev_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -694,7 +684,6 @@ def dot_prod_attention(
             rtol=1e-2,
         )
 
-    @skipIfRocm
     def _test_sdpa_rewriter_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -717,7 +706,6 @@ def dot_prod_attention(
 
         self._check_common(dot_prod_attention)
 
-    @skipIfRocm
     def _test_sdpa_rewriter_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -810,7 +798,6 @@ def dot_prod_attention(
             dot_prod_attention, args1=args, contains=False, has_dropout=True
         )
 
-    @skipIfRocm
     def _test_sdpa_rewriter_17(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
```

test/inductor/test_torchinductor.py (-4 lines)

```diff
@@ -10590,10 +10590,6 @@ def fn(z):
     def test_scaled_dot_product_attention(self):
         if self.device == "cuda" and not PLATFORM_SUPPORTS_FLASH_ATTENTION:
             raise unittest.SkipTest("Can't run flash attention on this platform")
-        if self.device == "cuda" and TEST_WITH_ROCM:
-            raise unittest.SkipTest(
-                "Flash attention support is incomplete on this platform"
-            )
 
         def fn(q, k, v):
             return torch.nn.functional.scaled_dot_product_attention(
```

torch/_inductor/fx_passes/fuse_attention.py (+1 -9 lines)

```diff
@@ -5,7 +5,6 @@
 import math
 
 import torch
-from torch.nn.attention import sdpa_kernel, SDPBackend
 
 from ..._dynamo.utils import counters
 from ..pattern_matcher import (
@@ -20,14 +19,7 @@
 aten = torch.ops.aten
 
 
-if torch.version.hip:
-
-    def _scaled_dot_product_attention(*args, **kwargs):
-        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
-            return aten.scaled_dot_product_attention(*args, **kwargs)
-
-else:
-    _scaled_dot_product_attention = aten.scaled_dot_product_attention
+_scaled_dot_product_attention = aten.scaled_dot_product_attention
 
 
 def _sfdp_pattern_1(query, key, value, inv_scale):
```
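
With the ROCm-specific wrapper removed, the pattern-matcher replacement calls `aten.scaled_dot_product_attention` directly on every platform. For reference, the public context manager the deleted wrapper relied on is still available to callers who want to pin SDPA to particular backends; a minimal sketch (shapes here are arbitrary, not taken from the PR):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Arbitrary (batch, heads, seq_len, head_dim) inputs for illustration.
q, k, v = (torch.randn(2, 4, 16, 64) for _ in range(3))

# Only the listed backends may be selected while inside the context.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```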

torch/_meta_registrations.py (+9 -3 lines)

```diff
@@ -4050,8 +4050,7 @@ def pool3d_shape_check(
     torch._check(
         dT > 0 and dW > 0 and dH > 0,
         lambda: (
-            f"stride should be greater than zero, but got "
-            f"dT: {dT}, dH: {dH}, dW: {dW}"
+            f"stride should be greater than zero, but got dT: {dT}, dH: {dH}, dW: {dW}"
         ),
     )
     torch._check(
@@ -5330,7 +5329,14 @@ def meta__scaled_dot_product_efficient_attention(
 
     res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
 
-    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    if torch.version.hip and torch.cuda.is_available():
+        """Please see: https://github.com/pytorch/pytorch/issues/146848
+        longsumexp last dim should be seq length
+        """
+        logsumexp_dim = M if compute_log_sumexp else 0
+    else:
+        logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+
     logsum_exp = torch.empty(
         (B, num_heads, logsumexp_dim),
         dtype=torch.float,
```
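
A minimal sketch of how the resulting logsumexp shape could be observed through the meta kernel, assuming it is acceptable to invoke the op directly on `meta` tensors; the reported size depends on whether the build is CUDA or ROCm:

```python
import torch

B, num_heads, M, K = 2, 4, 37, 64  # M deliberately not a multiple of 32
q, k, v = (
    torch.empty(B, num_heads, M, K, device="meta", dtype=torch.float16)
    for _ in range(3)
)

# Dispatch on meta tensors goes through the meta registration, so only shapes
# are computed; no attention kernel actually runs.
out, logsumexp, *_ = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, None, compute_log_sumexp=True
)
# CUDA builds report (B, num_heads, 64) here (padded to a multiple of 32);
# with this change, ROCm builds report (B, num_heads, 37) instead.
print(logsumexp.shape)
```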
