Commit f90a5f8

desertfire authored and pytorchmergebot committed
[AOTI][refactor][1/n] Rename cpp_kernel to cpp_kernel_name (pytorch#115783)
Differential Revision: [D52142184](https://our.internmc.facebook.com/intern/diff/D52142184)
Pull Request resolved: pytorch#115783
Approved by: https://github.com/chenyang78, https://github.com/jansel
1 parent 1b85992 commit f90a5f8
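
The rename is purely mechanical: every `cpp_kernel` attribute and constructor argument on the Inductor extern-kernel IR nodes becomes `cpp_kernel_name`, with no behavior change. Below is a minimal standalone sketch of the selection pattern this diff touches; `DemoGraph` and `DemoExternKernel` are hypothetical stand-ins for Inductor's `V.graph` and `ExternKernelAlloc`, used only to illustrate how `codegen_kernel_name` picks between the Python and C++ kernel names after the rename.

```python
# Minimal sketch, standard library only. DemoGraph and DemoExternKernel are
# hypothetical stand-ins for V.graph and ExternKernelAlloc; only the renamed
# attribute and the selection logic mirror the diff.
from dataclasses import dataclass


@dataclass
class DemoGraph:
    cpp_wrapper: bool  # mirrors V.graph.cpp_wrapper: True when emitting C++ wrapper code


@dataclass
class DemoExternKernel:
    kernel: str           # Python-side name, e.g. "aten.randint.low_out"
    cpp_kernel_name: str  # C++-side name, e.g. "at::randint_out" (formerly `cpp_kernel`)

    def codegen_kernel_name(self, graph: DemoGraph) -> str:
        # Same selection as ExternKernelAlloc.codegen_kernel_name after the rename.
        return self.cpp_kernel_name if graph.cpp_wrapper else self.kernel


randint = DemoExternKernel(kernel="aten.randint.low_out", cpp_kernel_name="at::randint_out")
print(randint.codegen_kernel_name(DemoGraph(cpp_wrapper=True)))   # at::randint_out
print(randint.codegen_kernel_name(DemoGraph(cpp_wrapper=False)))  # aten.randint.low_out
```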

File tree

2 files changed: +34 −34 lines


torch/_inductor/ir.py (+32 −32)
@@ -3902,7 +3902,7 @@ def codegen(self, wrapper):
             self.output_view,
             self.codegen_reference(),
             args,
-            self.cpp_kernel if V.graph.cpp_wrapper else self.kernel,
+            self.cpp_kernel_name if V.graph.cpp_wrapper else self.kernel,
         )
 
     def __init__(
@@ -3913,7 +3913,7 @@ def __init__(
         kwargs=None,
         output_view=None,
         kernel=None,
-        cpp_kernel=None,
+        cpp_kernel_name=None,
         ordered_kwargs_for_cpp_kernel=(),
     ):
         super().__init__(
@@ -3922,7 +3922,7 @@ def __init__(
         self.output_view = output_view
         self.name = V.graph.register_buffer(self)
         self.kernel = kernel
-        self.cpp_kernel = cpp_kernel
+        self.cpp_kernel_name = cpp_kernel_name
         self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
 
     def should_allocate(self):
@@ -3941,13 +3941,13 @@ def __init__(self, count: int, device: torch.device):
             inputs=[],
             constant_args=[limits.min, limits.max, [count]],
             kernel="aten.randint.low_out",
-            cpp_kernel="at::randint_out",
+            cpp_kernel_name="at::randint_out",
         )
 
 
 class ExternKernelAlloc(ExternKernel):
     def codegen_kernel_name(self):
-        return self.cpp_kernel if V.graph.cpp_wrapper else self.kernel
+        return self.cpp_kernel_name if V.graph.cpp_wrapper else self.kernel
 
     def codegen(self, wrapper):
         self.codegen_comment(wrapper)
@@ -3963,15 +3963,15 @@ def __init__(
         constant_args=(),
         kwargs=None,
         kernel=None,
-        cpp_kernel=None,
+        cpp_kernel_name=None,
         ordered_kwargs_for_cpp_kernel=(),
     ):
         super().__init__(
             None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
         )
         self.name = V.graph.register_buffer(self)
         self.kernel = kernel
-        self.cpp_kernel = cpp_kernel
+        self.cpp_kernel_name = cpp_kernel_name
         self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
 
     def should_allocate(self):
@@ -4182,7 +4182,7 @@ def codegen(self, wrapper):
             get_operator_enum = {"add": "sum", "multiply": "prod"}
             if reduce in get_operator_enum:
                 reduce = get_operator_enum[reduce]
-            self.cpp_kernel = self.get_cpp_kernel(self.fn, reduce)
+            self.cpp_kernel_name = self.get_cpp_kernel(self.fn, reduce)
 
         if self.src_is_tensor:
             (x, index, src) = (t.codegen_reference() for t in self.inputs)
@@ -4192,7 +4192,7 @@ def codegen(self, wrapper):
         wrapper.generate_scatter_fallback(
             x,
             [x, self.constant_args[0], index, src],
-            self.cpp_kernel if V.graph.cpp_wrapper else self.kernel,
+            self.cpp_kernel_name if V.graph.cpp_wrapper else self.kernel,
             self.fn,
             self.src_is_tensor,
             reduce,
@@ -4281,7 +4281,7 @@ def codegen(self, wrapper):
         args = [x, indices_str, values, *self.codegen_const_args()]
         wrapper.writeline(
             wrapper.wrap_kernel_call(
-                self.cpp_kernel if V.graph.cpp_wrapper else self.kernel, args
+                self.cpp_kernel_name if V.graph.cpp_wrapper else self.kernel, args
             )
         )
 
@@ -4305,7 +4305,7 @@ def __init__(self, x, indices, values, accumulate):
             (accumulate,),
         )
         self.name = V.graph.register_buffer(self)
-        self.cpp_kernel = "at::index_put_"
+        self.cpp_kernel_name = "at::index_put_"
         self.kernel = "aten.index_put_"
         mark_node_as_mutating(self, x)
 
@@ -4452,10 +4452,10 @@ def is_not_write(arg):
             is_not_write(x) for x in kernel._schema.returns
         ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper"
 
-        self.cpp_kernel = kernel._schema.name
+        self.cpp_kernel_name = kernel._schema.name
         self.cpp_kernel_overload_name = kernel._schema.overload_name
         self.cpp_kernel_key = (
-            f"{self.cpp_kernel.replace('::', '_')}_{self.cpp_kernel_overload_name}"
+            f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"
         )
 
         self.cpp_op_schema = get_cpp_op_schema(kernel)
@@ -4489,14 +4489,14 @@ def sdpa_ver_fn():
                 self.get_kwargs_value(arg_name) is None
                 for arg_name in self.ordered_kwargs_for_cpp_kernel
             ):
-                return f"{self.cpp_kernel}_v2"
+                return f"{self.cpp_kernel_name}_v2"
             else:
-                return self.cpp_kernel
+                return self.cpp_kernel_name
 
         kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
-        if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
+        if (ver_fn := kernel_to_ver.get(self.cpp_kernel_name, None)) is not None:
             return ver_fn()
-        return self.cpp_kernel
+        return self.cpp_kernel_name
 
     def codegen_args(self):
         @dataclasses.dataclass
@@ -4667,7 +4667,7 @@ def codegen(self, wrapper):
             # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
             # repeat_interleave(const at::Tensor & self, int64_t repeats,
             #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-            self.cpp_kernel = (
+            self.cpp_kernel_name = (
                 f"at::{op_base_name}"
                 if kernel._overloadname == "default"
                 else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
@@ -5121,7 +5121,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._convolution_pointwise",
-            cpp_kernel="mkldnn::_convolution_pointwise",
+            cpp_kernel_name="mkldnn::_convolution_pointwise",
         )
         self.cpp_kernel_key = "convolution_pointwise"
         self.cpp_op_schema = """
@@ -5140,7 +5140,7 @@ def __init__(
     def codegen(self, wrapper):
         wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
             self.get_name(),
-            self.cpp_kernel if V.graph.cpp_wrapper else self.kernel,
+            self.cpp_kernel_name if V.graph.cpp_wrapper else self.kernel,
             self.codegen_args(),
             self.cpp_op_schema,
             self.cpp_kernel_key,
@@ -5191,7 +5191,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._convolution_pointwise.binary",
-            cpp_kernel="mkldnn::_convolution_pointwise",
+            cpp_kernel_name="mkldnn::_convolution_pointwise",
         )
         self.cpp_kernel_overload_name = "binary"
         self.cpp_kernel_key = "convolution_pointwise_binary"
@@ -5281,7 +5281,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
-            cpp_kernel="mkldnn::_convolution_pointwise_",
+            cpp_kernel_name="mkldnn::_convolution_pointwise_",
         )
         self.cpp_kernel_overload_name = "binary"
         self.cpp_kernel_key = "convolution_pointwise_binary_"
@@ -5377,7 +5377,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkl._mkl_linear",
-            cpp_kernel="mkl::_mkl_linear",
+            cpp_kernel_name="mkl::_mkl_linear",
        )
         self.cpp_kernel_key = "mkl_linear"
         self.cpp_op_schema = """
@@ -5430,7 +5430,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._linear_pointwise",
-            cpp_kernel="mkldnn::_linear_pointwise",
+            cpp_kernel_name="mkldnn::_linear_pointwise",
         )
         self.cpp_kernel_key = "linear_pointwise"
         self.cpp_op_schema = """
@@ -5495,7 +5495,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._linear_pointwise.binary",
-            cpp_kernel="mkldnn::_linear_pointwise",
+            cpp_kernel_name="mkldnn::_linear_pointwise",
         )
         self.cpp_kernel_overload_name = "binary"
         self.cpp_kernel_key = "linear_pointwise_binary"
@@ -5562,7 +5562,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
-            cpp_kernel="mkldnn::_convolution_transpose_pointwise",
+            cpp_kernel_name="mkldnn::_convolution_transpose_pointwise",
         )
         self.cpp_kernel_key = "convolution_transpose_pointwise"
         self.cpp_op_schema = """
@@ -5646,7 +5646,7 @@ def __init__(
             constant_args,
             None,
             kernel="aten.mkldnn_rnn_layer",
-            cpp_kernel="at::mkldnn_rnn_layer",
+            cpp_kernel_name="at::mkldnn_rnn_layer",
         )
 
     @classmethod
@@ -5766,7 +5766,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.onednn.qconv2d_pointwise",
-            cpp_kernel="onednn::qconv2d_pointwise",
+            cpp_kernel_name="onednn::qconv2d_pointwise",
         )
         self.cpp_kernel_key = "qconv2d_pointwise"
         self.cpp_op_schema = """
@@ -5936,7 +5936,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.onednn.qconv2d_pointwise.binary",
-            cpp_kernel="onednn::qconv2d_pointwise",
+            cpp_kernel_name="onednn::qconv2d_pointwise",
         )
         self.cpp_kernel_overload_name = "binary"
         self.cpp_kernel_key = "qconv2d_pointwise_binary"
@@ -6136,7 +6136,7 @@ def __init__(
             constant_args,
             None,
             kernel="torch.ops.onednn.qlinear_pointwise",
-            cpp_kernel="onednn::qlinear_pointwise",
+            cpp_kernel_name="onednn::qlinear_pointwise",
         )
         self.cpp_kernel_key = "qlinear_pointwise"
         self.cpp_op_schema = """
@@ -7223,10 +7223,10 @@ def has_side_effects(self):
     def set_cpp_kernel(self, kernel):
         from .codegen.wrapper import get_cpp_op_schema
 
-        self.cpp_kernel = kernel._schema.name
+        self.cpp_kernel_name = kernel._schema.name
         self.cpp_kernel_overload_name = kernel._schema.overload_name
         self.cpp_kernel_key = (
-            f"{self.cpp_kernel.replace('::', '_')}_{self.cpp_kernel_overload_name}"
+            f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"
         )
 
         self.cpp_op_schema = get_cpp_op_schema(kernel)
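
Call sites that construct these IR nodes, such as the mkldnn, mkl, and onednn fusion ops above, change only the keyword they pass. A hedged sketch of the before/after at a call site follows; `make_demo_node` is a hypothetical helper, not the real `ExternKernelAlloc` constructor.

```python
# Hypothetical helper standing in for an ExternKernelAlloc-style constructor;
# it only illustrates the keyword rename visible at the call sites in this diff.
def make_demo_node(kernel: str, cpp_kernel_name: str) -> dict:
    # Before pytorch#115783 callers passed `cpp_kernel=`; afterwards they pass
    # `cpp_kernel_name=` with the same value.
    return {"kernel": kernel, "cpp_kernel_name": cpp_kernel_name}


node = make_demo_node(
    kernel="torch.ops.mkldnn._convolution_pointwise",
    cpp_kernel_name="mkldnn::_convolution_pointwise",
)
print(node["cpp_kernel_name"])  # mkldnn::_convolution_pointwise
```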

torch/_inductor/select_algorithm.py (+2 −2)
@@ -559,7 +559,7 @@ def __init__(
         assert callable(kernel)
         assert not hasattr(extern_kernels, name), "duplicate extern kernel"
         self.name = name
-        self.cpp_kernel = cpp_kernel
+        self.cpp_kernel_name = cpp_kernel
         self.has_out_variant = has_out_variant
         setattr(extern_kernels, name, kernel)
 
@@ -687,7 +687,7 @@ def output_node(self):
                 layout=self.layout,
                 inputs=self.input_nodes,
                 kernel=self.choice.call_name(),
-                cpp_kernel=self.choice.cpp_kernel,
+                cpp_kernel_name=self.choice.cpp_kernel_name,
                 ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                 kwargs=self.kwargs,
             )
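
Note that in `ExternKernelChoice.__init__` the constructor parameter keeps its old name `cpp_kernel`; only the stored attribute becomes `cpp_kernel_name`, and `output_node` forwards it under the new keyword. A small sketch of that wiring; `DemoChoice` and its values are hypothetical stand-ins, not the real `ExternKernelChoice`.

```python
# Hypothetical stand-in for ExternKernelChoice showing the attribute rename:
# the __init__ parameter keeps its old name, the attribute gets the new one.
class DemoChoice:
    def __init__(self, name: str, cpp_kernel: str):
        self.name = name
        self.cpp_kernel_name = cpp_kernel  # renamed attribute, old parameter name

    def output_kwargs(self) -> dict:
        # Mirrors how output_node now forwards the value under the new keyword.
        return {"kernel": self.name, "cpp_kernel_name": self.cpp_kernel_name}


choice = DemoChoice(name="extern_mm", cpp_kernel="at::mm_out")
print(choice.output_kwargs())  # {'kernel': 'extern_mm', 'cpp_kernel_name': 'at::mm_out'}
```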
