@@ -362,64 +362,13 @@ def fn2(x):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
-    @requires_gpu()
-    @config.patch({"fx_graph_cache": True})
-    @config.patch({"fx_graph_remote_cache": False})
-    def test_flex_attention_caching(self):
-        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-
-        block_mask = create_block_mask(
-            lambda b, h, q, kv: q >= kv, None, None, 2048, 2048
-        )
-
-        def score_mod(score, b, h, q, kv):
-            return score + (q - kv)
-
-        def fn(q, k, v):
-            return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
-
-        def score_mod2(score, b, h, q, kv):
-            return score
-
-        def fn2(q, k, v):
-            return flex_attention(q, k, v, score_mod=score_mod2, block_mask=block_mask)
-
-        a, b, c = (torch.randn(1, 4, 512, 64).cuda() for _ in range(3))
-        compiled_fn = torch.compile(fn)
-        compiled_fn2 = torch.compile(fn2)
-
-        # A first call should miss in the cache.
-        self.assertEqual(fn(a, b, c), compiled_fn(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
-
-        # A second call should hit. (First reset so in-memory guards
-        # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        self.reset()
-        self.assertEqual(fn(a, b, c), compiled_fn(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
-
-        # A third call with different score_mod should have a cache miss
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        self.reset()
-        self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
-
     @requires_gpu()
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
-    def test_triton_higher_order_op_bypass(self):
+    def test_higher_order_op_bypass(self):
         """
-        Verify that we bypass the cache when we have a triton higher order ops.
+        Verify that we bypass the cache when we have higher order ops.
         """
 
         def fn(x, y):