@@ -105,7 +105,6 @@ def _check_common(
         ):
             self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

-    @skipIfRocm
     def _test_sdpa_rewriter_1(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -132,7 +131,6 @@ def dot_prod_attention(
             rtol=rtol,
         )

-    @skipIfRocm
     @torch._inductor.config.patch("freezing", True)
     def _test_sdpa_rewriter_1_freezing(self):
         def dot_prod_attention(
@@ -264,7 +262,6 @@ def dot_prod_attention(
         _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
         self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

-    @skipIfRocm
     def _test_sdpa_rewriter_2(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -279,7 +276,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention)
         self._check_common(checkpoint_wrapper(dot_prod_attention))

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_3(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
@@ -296,7 +292,6 @@ def dot_prod_attention(
             checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
         )

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_4(self):
         def dot_prod_attention(
             query: torch.Tensor,
@@ -346,7 +341,6 @@ def sfdp_pattern_5_v2(query, key, value):
         self._check_common(sfdp_pattern_5_v2, contains=False)
         self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

-    @skipIfRocm
     def _test_sdpa_rewriter_6(self):
         def sfdp_pattern_6(query, key, value, training):
             attn_mask = torch.ones(
@@ -570,7 +564,6 @@ def forward(self, query, key, value, attn_mask) -> torch.Tensor:
             model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_11(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -611,7 +604,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention, contains=False, has_dropout=True)

-    @skipIfRocm
     def _test_sdpa_prev_13(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -628,7 +620,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -644,7 +635,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -694,7 +684,6 @@ def dot_prod_attention(
             rtol=1e-2,
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -717,7 +706,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention)

-    @skipIfRocm
     def _test_sdpa_rewriter_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -810,7 +798,6 @@ def dot_prod_attention(
             dot_prod_attention, args1=args, contains=False, has_dropout=True
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_17(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
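
For context, here is a minimal, self-contained sketch of the kind of manual dot_prod_attention these tests compile; it is not the exact body from the test file, and the shapes, scale, and tolerances are illustrative assumptions. Inductor's SDPA pattern matcher is expected to rewrite hand-written attention math like this into a fused scaled_dot_product_attention call, while the compiled output stays numerically close to eager mode.

# Hypothetical standalone example; names and shapes are assumptions, not the test's exact values.
import math

import torch


def dot_prod_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    # softmax(Q @ K^T / sqrt(d)) @ V, written out "by hand" so the
    # pattern matcher has something to rewrite into fused SDPA.
    scale = 1.0 / math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    return torch.matmul(scores.softmax(dim=-1), value)


if __name__ == "__main__":
    q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
    eager = dot_prod_attention(q, k, v)
    compiled = torch.compile(dot_prod_attention)(q, k, v)
    # The compiled (possibly fused) result should match eager within tolerance.
    torch.testing.assert_close(eager, compiled, atol=1e-5, rtol=1e-5)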