
Commit e9bfbf7

Revert "Allow fx graph caching higher order operators (opt-in) (pytorch#135877)"
This reverts commit 66d5eb6. Reverted pytorch#135877 on behalf of https://github.com/jeanschmidt because it appears to have introduced regressions in ROCm signals ([comment](pytorch#135877 (comment))).
1 parent 75f141b · commit e9bfbf7

6 files changed (+20, -79 lines)

test/inductor/test_codecache.py

Lines changed: 2 additions & 53 deletions
@@ -362,64 +362,13 @@ def fn2(x):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
-    @requires_gpu()
-    @config.patch({"fx_graph_cache": True})
-    @config.patch({"fx_graph_remote_cache": False})
-    def test_flex_attention_caching(self):
-        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-
-        block_mask = create_block_mask(
-            lambda b, h, q, kv: q >= kv, None, None, 2048, 2048
-        )
-
-        def score_mod(score, b, h, q, kv):
-            return score + (q - kv)
-
-        def fn(q, k, v):
-            return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
-
-        def score_mod2(score, b, h, q, kv):
-            return score
-
-        def fn2(q, k, v):
-            return flex_attention(q, k, v, score_mod=score_mod2, block_mask=block_mask)
-
-        a, b, c = (torch.randn(1, 4, 512, 64).cuda() for _ in range(3))
-        compiled_fn = torch.compile(fn)
-        compiled_fn2 = torch.compile(fn2)
-
-        # A first call should miss in the cache.
-        self.assertEqual(fn(a, b, c), compiled_fn(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
-
-        # A second call should hit. (First reset so in-memory guards
-        # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        self.reset()
-        self.assertEqual(fn(a, b, c), compiled_fn(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
-
-        # A third call with different score_mod should have a cache miss
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        self.reset()
-        self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
-
     @requires_gpu()
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
-    def test_triton_higher_order_op_bypass(self):
+    def test_higher_order_op_bypass(self):
         """
-        Verify that we bypass the cache when we have a triton higher order ops.
+        Verify that we bypass the cache when we have higher order ops.
         """
 
         def fn(x, y):
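
For orientation, the retained test_higher_order_op_bypass exercises the behavior this revert restores: any higher-order operator in the compiled graph causes the FX graph cache to be bypassed rather than keyed. Below is a minimal, illustrative sketch of that behavior outside the test suite; it assumes torch.cond is captured as a higher-order op (as in current PyTorch) and uses the same counters bookkeeping as the test above. It is not the test body itself (the body of fn(x, y) is not shown in this diff).

```python
# Illustrative sketch (not part of this commit): compiling a function that
# contains a higher-order op should leave the FX graph cache untouched,
# because _check_can_cache raises BypassFxGraphCache for such graphs.
import torch
from torch._dynamo.utils import counters
from torch._inductor import config


def fn(x, y):
    # torch.cond lowers to torch.ops.higher_order.cond in the captured graph.
    return torch.cond(x.sum() > 0, lambda a, b: a + b, lambda a, b: a - b, (x, y))


with config.patch({"fx_graph_cache": True, "fx_graph_remote_cache": False}):
    compiled = torch.compile(fn)
    x, y = torch.randn(4), torch.randn(4)
    compiled(x, y)

# With the cache bypassed, neither a hit nor a miss should be recorded.
print(counters["inductor"]["fxgraph_cache_miss"])  # expected: 0
print(counters["inductor"]["fxgraph_cache_hit"])   # expected: 0
```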

torch/_higher_order_ops/flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 
 class FlexAttentionHOP(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("flex_attention", cacheable=True)
+        super().__init__("flex_attention")
 
     def __call__(
         self,

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 2 additions & 2 deletions
@@ -523,7 +523,7 @@ def identify_mutated_tensors(kernel, kwargs):
 # Used for wrapping a Triton Kernel
 class TritonKernelWrapperMutation(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("triton_kernel_wrapper_mutation", cacheable=False)
+        super().__init__("triton_kernel_wrapper_mutation")
 
     def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
         return super().__call__(
@@ -540,7 +540,7 @@ def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
 # Used for wrapping a Triton Kernel in a functional manner
 class TritonKernelWrapperFunctional(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("triton_kernel_wrapper_functional", cacheable=False)
+        super().__init__("triton_kernel_wrapper_functional")
 
     def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
         return super().__call__(

torch/_higher_order_ops/wrap.py

Lines changed: 2 additions & 2 deletions
@@ -107,7 +107,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
     """
 
     def __init__(self) -> None:
-        super().__init__("wrap_activation_checkpoint", cacheable=False)
+        super().__init__("wrap_activation_checkpoint")
 
     def __call__(self, function, *args, **kwargs):
         # use_reentrant is set to False because this op is going to be traced.
@@ -146,7 +146,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
     """
 
     def __init__(self) -> None:
-        super().__init__("tag_activation_checkpoint", cacheable=False)
+        super().__init__("tag_activation_checkpoint")
 
     @staticmethod
     def divide_kwargs(kwargs):

torch/_inductor/codecache.py

Lines changed: 12 additions & 16 deletions
@@ -1268,22 +1268,18 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
             log.debug("fx graph cache no shape env")
             raise BypassFxGraphCache("No shape env")
 
-        # We skip caching if there are any torchbind objects.
-        for module in gm.modules():
-            if not isinstance(module, torch.fx.GraphModule):
-                continue
-            for node in module.graph.nodes:
-                if (
-                    isinstance(node.target, torch._ops.HigherOrderOperator)
-                    and not node.target.cacheable()
-                ):
-                    raise BypassFxGraphCache(
-                        f"Can't cache HigherOrderOperator: {node.target.name()}"
-                    )
-                if node.op == "getattr" and isinstance(
-                    getattr(gm, node.target), torch._C.ScriptObject
-                ):
-                    raise BypassFxGraphCache("Can't cache torchbind objects")
+        # HigherOrderOperators should be handled on a case-by-case basis.
+        # Currently, we just skip caching if we have any.
+        # We also skip if there are any torchbind objects.
+        for node in gm.graph.nodes:
+            if isinstance(node.target, torch._ops.HigherOrderOperator):
+                raise BypassFxGraphCache(
+                    f"Can't cache HigherOrderOperator: {node.target.name()}"
+                )
+            if node.op == "getattr" and isinstance(
+                getattr(gm, node.target), torch._C.ScriptObject
+            ):
+                raise BypassFxGraphCache("Can't cache torchbind objects")
 
     @staticmethod
     def prepare_key(
torch/_ops.py

Lines changed: 1 addition & 5 deletions
@@ -245,7 +245,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # If you're creating a new HigherOrderOperator, please do not change the
     # default. Adding operators to the global torch.ops namespace is a bad
     # practice due to name collisions.
-    def __init__(self, name, *, cacheable=False):
+    def __init__(self, name):
         super().__init__()
         if type(self) is HigherOrderOperator:
             raise RuntimeError(
@@ -258,7 +258,6 @@ def __init__(self, name, *, cacheable=False):
         _higher_order_ops[name] = self
         self._ns = "higher_order"
         self.__module__ = "torch.ops.higher_order"
-        self._cacheable = cacheable
 
         self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
 
@@ -282,9 +281,6 @@ def py_impl(self, k):
     def namespace(self):
         return self._ns
 
-    def cacheable(self):
-        return self._cacheable
-
     def fallthrough(self, dispatch_key):
         self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
 
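
With the cacheable keyword and the cacheable() accessor removed, subclasses go back to passing only the operator name to HigherOrderOperator.__init__, as the files above now do. A hedged sketch of the post-revert pattern follows; MyWrapper and my_wrapper are hypothetical names used purely for illustration.

```python
# Hypothetical example of defining a higher-order operator after this revert:
# __init__ takes only the name; there is no cacheable= keyword and no
# cacheable() method, so FxGraphCache always bypasses graphs containing it.
import torch
from torch._ops import HigherOrderOperator


class MyWrapper(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("my_wrapper")  # before this revert, could opt in with cacheable=True

    def __call__(self, fn, *args, **kwargs):
        # Dispatch through the HigherOrderOperator machinery, as the wrappers above do.
        return super().__call__(fn, *args, **kwargs)


my_wrapper = MyWrapper()
print(my_wrapper.name())                   # "my_wrapper"
print(hasattr(my_wrapper, "cacheable"))    # False after this revert
```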
