@@ -4175,6 +4175,11 @@ class ExternKernelNode:
     node: export_schema.Node
 
 
+fbcode_use_proxy_executor = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+}
+
+
 @dataclasses.dataclass
 class FallbackKernel(ExternKernelAlloc):
     def __init__(
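Note on the new module-level set: attribute access on an overload packet returns a cached `OpOverload` object (to the best of my knowledge), so membership tests against this set are cheap lookups. A minimal sketch, not part of the diff, mirroring the check added in the next hunk:

import torch

# Hedged sketch, not part of the diff: repeated attribute access returns the
# same cached OpOverload object, so set membership is a cheap lookup.
fbcode_use_proxy_executor = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
}

kernel = torch.ops.aten._scaled_dot_product_efficient_attention.default
print(kernel in fbcode_use_proxy_executor)  # True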
@@ -4208,41 +4213,43 @@ def __init__(
             ),
         ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
 
-        if kernel.__module__ == "torch._ops.aten":
-            op_base_name = (
-                kernel.__name__.split(".")[0]
-                if isinstance(kernel, torch._ops.OpOverload)
-                else kernel.__name__
-            )
+        if kernel.namespace == "aten":
+            # Aten Fallback Ops
+            assert isinstance(kernel, torch._ops.OpOverload)
+            op_base_name = kernel.__name__.split(".")[0]
+
             if V.graph.cpp_wrapper:
-                assert isinstance(kernel, torch._ops.OpOverload)
-                # Calling with the default kernel name can lead to ambiguous behavior like the following example.
-                # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
-                # repeat_interleave(const at::Tensor & self, int64_t repeats,
-                #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-                self.kernel = (
-                    f"at::{op_base_name}"
-                    if kernel._overloadname == "default"
-                    else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
-                )
-                schema = kernel._schema
+                if config.is_fbcode() and kernel in fbcode_use_proxy_executor:
+                    self.use_cpp_op_schema = True
+                    self.set_cpp_kernel(kernel)
+                else:
+                    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
+                    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
+                    # repeat_interleave(const at::Tensor & self, int64_t repeats,
+                    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
+                    self.kernel = (
+                        f"at::{op_base_name}"
+                        if kernel._overloadname == "default"
+                        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
+                    )
+                    schema = kernel._schema
+
+                    self.args_default_value = [
+                        {"type": x.real_type, "value": x.default_value}
+                        for x in schema.arguments
+                        if not x.kwarg_only
+                    ]
+                    self.ordered_kwargs_for_cpp_kernel = [
+                        x.name for x in schema.arguments if x.kwarg_only
+                    ]
+                    self.kwargs_default_value = {
+                        x.name: {"type": x.real_type, "value": x.default_value}
+                        for x in schema.arguments
+                        if x.kwarg_only
+                    }
             else:
                 self.kernel = f"aten.{op_base_name}"
 
-            if schema is not None:
-                self.args_default_value = [
-                    {"type": x.real_type, "value": x.default_value}
-                    for x in schema.arguments
-                    if not x.kwarg_only
-                ]
-                self.ordered_kwargs_for_cpp_kernel = [
-                    x.name for x in schema.arguments if x.kwarg_only
-                ]
-                self.kwargs_default_value = {
-                    x.name: {"type": x.real_type, "value": x.default_value}
-                    for x in schema.arguments
-                    if x.kwarg_only
-                }
         elif isinstance(kernel, torch._ops.HigherOrderOperator):
             if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel:
                 self.kernel = f"torch._prims.rng_prims.{kernel.__name__}"
@@ -4251,6 +4258,7 @@ def __init__(
                     "Unable to find HigherOrderOperator kernel name"
                 )
         else:
+            # For non-aten OpOverload, i.e. custom ops
            if V.graph.cpp_wrapper:
                 self.use_cpp_op_schema = True
                 self.set_cpp_kernel(kernel)
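For the branch annotated above: any OpOverload outside the aten namespace lands here. A hedged sketch of such an op (the mylib::my_op registration is hypothetical, purely for illustration):

import torch

# Hedged sketch, not part of the diff: define a custom (non-aten) op.
# "mylib" and "my_op" are hypothetical names used only for illustration.
lib = torch.library.Library("mylib", "DEF")
lib.define("my_op(Tensor x) -> Tensor")

kernel = torch.ops.mylib.my_op.default
print(kernel.namespace)  # mylib -- not "aten", so FallbackKernel would take
                         # the cpp-op-schema path under cpp_wrapper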
@@ -4408,7 +4416,7 @@ def handle_single_output(return_type, output):
                     ]
                 )
             else:
-                raise RuntimeError("Unsupported return type")
+                raise RuntimeError(f"Unsupported return type {type(return_type)}")
 
         target = self.op_overload
         returns = target._schema.returns
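Context for the improved message: return_type is a JIT schema type taken from target._schema.returns, so including type(return_type) names the unhandled class directly. A hedged sketch of where that value comes from (example op chosen for illustration):

import torch

# Hedged sketch, not part of the diff: the schema return types that
# handle_single_output matches against.
target = torch.ops.aten.repeat_interleave.self_int
returns = target._schema.returns
return_type = returns[0].type
print(type(return_type))  # <class 'torch.TensorType'> is handled; any other
                          # type now raises with its class spelled out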