Commit 728ed37

SherlockNoMad authored and pytorchmergebot committed
[AOTInductor] Allow using ProxyExecutor for ATen fallbacks (pytorch#112976)
Summary: Use ProxyExecutor for aten._scaled_dot_product_efficient_attention in ABI-mode.

Test Plan: OSS CI

Differential Revision: D51005807

Pull Request resolved: pytorch#112976
Approved by: https://github.com/chenyang78, https://github.com/jansel
1 parent df4f0b3 commit 728ed37
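
For orientation before the diff: the new dispatch logic keys off a handful of torch._ops.OpOverload attributes (namespace, __name__, _overloadname, _schema). Below is a minimal sketch of inspecting the op this commit allowlists; it assumes nothing beyond the attributes the diff itself uses:

    import torch

    # The single op placed on the fbcode_use_proxy_executor allowlist below.
    op = torch.ops.aten._scaled_dot_product_efficient_attention.default

    assert isinstance(op, torch._ops.OpOverload)
    print(op.namespace)               # "aten" -- replaces the old __module__ check
    print(op.__name__)                # "_scaled_dot_product_efficient_attention.default"
    print(op.__name__.split(".")[0])  # op_base_name, as computed in the diff
    print(op._overloadname)           # "default"
    print(op._schema)                 # schema used for args/kwargs default values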

File tree

1 file changed (+40, -32 lines)


torch/_inductor/ir.py

Lines changed: 40 additions & 32 deletions
@@ -4175,6 +4175,11 @@ class ExternKernelNode:
     node: export_schema.Node
 
 
+fbcode_use_proxy_executor = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+}
+
+
 @dataclasses.dataclass
 class FallbackKernel(ExternKernelAlloc):
     def __init__(
@@ -4208,41 +4213,43 @@ def __init__(
             ),
         ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
 
-        if kernel.__module__ == "torch._ops.aten":
-            op_base_name = (
-                kernel.__name__.split(".")[0]
-                if isinstance(kernel, torch._ops.OpOverload)
-                else kernel.__name__
-            )
+        if kernel.namespace == "aten":
+            # Aten Fallback Ops
+            assert isinstance(kernel, torch._ops.OpOverload)
+            op_base_name = kernel.__name__.split(".")[0]
+
             if V.graph.cpp_wrapper:
-                assert isinstance(kernel, torch._ops.OpOverload)
-                # Calling with the default kernel name can lead to ambiguous behavior like the following example.
-                # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
-                # repeat_interleave(const at::Tensor & self, int64_t repeats,
-                #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-                self.kernel = (
-                    f"at::{op_base_name}"
-                    if kernel._overloadname == "default"
-                    else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
-                )
-                schema = kernel._schema
+                if config.is_fbcode() and kernel in fbcode_use_proxy_executor:
+                    self.use_cpp_op_schema = True
+                    self.set_cpp_kernel(kernel)
+                else:
+                    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
+                    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
+                    # repeat_interleave(const at::Tensor & self, int64_t repeats,
+                    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
+                    self.kernel = (
+                        f"at::{op_base_name}"
+                        if kernel._overloadname == "default"
+                        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
+                    )
+                    schema = kernel._schema
+
+                    self.args_default_value = [
+                        {"type": x.real_type, "value": x.default_value}
+                        for x in schema.arguments
+                        if not x.kwarg_only
+                    ]
+                    self.ordered_kwargs_for_cpp_kernel = [
+                        x.name for x in schema.arguments if x.kwarg_only
+                    ]
+                    self.kwargs_default_value = {
+                        x.name: {"type": x.real_type, "value": x.default_value}
+                        for x in schema.arguments
+                        if x.kwarg_only
+                    }
             else:
                 self.kernel = f"aten.{op_base_name}"
 
-            if schema is not None:
-                self.args_default_value = [
-                    {"type": x.real_type, "value": x.default_value}
-                    for x in schema.arguments
-                    if not x.kwarg_only
-                ]
-                self.ordered_kwargs_for_cpp_kernel = [
-                    x.name for x in schema.arguments if x.kwarg_only
-                ]
-                self.kwargs_default_value = {
-                    x.name: {"type": x.real_type, "value": x.default_value}
-                    for x in schema.arguments
-                    if x.kwarg_only
-                }
         elif isinstance(kernel, torch._ops.HigherOrderOperator):
             if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel:
                 self.kernel = f"torch._prims.rng_prims.{kernel.__name__}"
@@ -4251,6 +4258,7 @@ def __init__(
                     "Unable to find HigherOrderOperator kernel name"
                 )
         else:
+            # For non-aten OpOverload, i.e. custom ops
            if V.graph.cpp_wrapper:
                 self.use_cpp_op_schema = True
                 self.set_cpp_kernel(kernel)
@@ -4408,7 +4416,7 @@ def handle_single_output(return_type, output):
                 ]
             )
         else:
-            raise RuntimeError("Unsupported return type")
+            raise RuntimeError(f"Unsupported return type {type(return_type)}")
 
         target = self.op_overload
         returns = target._schema.returns
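
The overload-name disambiguation that the cpp_wrapper path keeps (at::{name} for the default overload vs. at::_ops::{name}_{overload}::call otherwise) can be read as a standalone rule. Here is a minimal sketch mirroring the logic in the diff above; the helper name is hypothetical:

    import torch

    def cpp_kernel_name(kernel: torch._ops.OpOverload) -> str:
        # The default overload can be called as at::<name>, but several
        # overloads may share a base name (see the repeat_interleave comment
        # in the diff), so any non-default overload is routed through the
        # unambiguous at::_ops::<name>_<overload>::call entry point.
        base = kernel.__name__.split(".")[0]
        if kernel._overloadname == "default":
            return f"at::{base}"
        return f"at::_ops::{kernel.__name__.replace('.', '_')}::call"

    print(cpp_kernel_name(torch.ops.aten.relu.default))
    # -> "at::relu"
    print(cpp_kernel_name(torch.ops.aten.repeat_interleave.self_int))
    # -> "at::_ops::repeat_interleave_self_int::call"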

0 commit comments