
Commit bd19d6d

desertfire authored and pytorchmergebot committed
[AOTI] Use torchgen to generate C shim functions (pytorch#120513)
Summary: The current C shim layer manually implements a C interface for a handful of ops. That approach obviously doesn't scale if we want to cover all aten ops. This new torchgen script automatically generates C shim interfaces for the CPU and CUDA backends. The interface follows the same parameter-passing rules as the current C shim layer:

* Use plain C data types to pass parameters
* Use AtenTensorHandle to pass at::Tensor
* Use a pointer type to pass an optional parameter
* Use pointer+length to pass a list
* Use device_type+device_index to pass a device
* When a parameter is a pointer of pointer, e.g. AtenTensorHandle**, the script generates either a list of optional values or an optional list of values

https://gist.github.com/desertfire/83701532b126c6d34dae6ba68a1b074a is an example of the generated torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp file. The current version doesn't generate C shim wrappers for every aten op, and on the other hand probably generates more wrappers than needed, but it should serve as a good basis.

This PR by itself doesn't change AOTI codegen and thus won't introduce any FC breakage. The actual wrapper codegen changes will come in a follow-up PR gated by a version-control flag to avoid FC breakage.

Differential Revision: [D54258087](https://our.internmc.facebook.com/intern/diff/D54258087)

Pull Request resolved: pytorch#120513
Approved by: https://github.com/jansel
1 parent ffe45a8 commit bd19d6d
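
To make the parameter-passing rules concrete, below is a minimal, self-contained sketch of the kind of C declarations such a shim exposes. Everything in it (the type aliases, the function names, and the chosen argument lists) is a hypothetical stand-in, not the actual generated interface; the real output is the gist linked above.

```cpp
// Illustrative sketch only -- all names are hypothetical stand-ins chosen to
// show how the parameter-passing rules map onto a C ABI.
#include <cstdint>

struct AtenTensorOpaque;                     // opaque stand-in for at::Tensor
using AtenTensorHandle = AtenTensorOpaque*;  // tensors cross the ABI as handles
using AOTITorchError = int32_t;              // plain C error-code return type

extern "C" {

// Tensor arguments -> AtenTensorHandle; Scalar arguments -> plain C doubles;
// the result tensor comes back through an out-pointer handle.
AOTITorchError aoti_torch_cpu_addmm_example(
    AtenTensorHandle self,
    AtenTensorHandle mat1,
    AtenTensorHandle mat2,
    double beta,
    double alpha,
    AtenTensorHandle* ret0);

// int[] -> pointer + length; float? -> nullable pointer (nullptr means nullopt);
// Device -> device_type + device_index as plain 32-bit ints.
AOTITorchError aoti_torch_cpu_strided_op_example(
    AtenTensorHandle self,
    const int64_t* stride,
    int64_t stride_len_,
    double* maybe_scale,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret0);

// Tensor?[] -> pointer of pointer + length, i.e. a list of optional tensors
// where a null inner pointer means None at that position.
AOTITorchError aoti_torch_cpu_index_example(
    AtenTensorHandle self,
    const AtenTensorHandle** indices,
    int64_t indices_len_,
    AtenTensorHandle* ret0);

}  // extern "C"
```

The pointer-of-pointer case is the one called out in the summary: depending on the op's schema, the generator treats AtenTensorHandle** either as a list of optional values or as an optional list of values.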


6 files changed (+611, -8 lines)


.gitignore (+1)

@@ -86,6 +86,7 @@ torch/csrc/api/include/torch/version.h
 torch/csrc/cudnn/cuDNN.cpp
 torch/csrc/generated
 torch/csrc/generic/TensorMethods.cpp
+torch/csrc/inductor/aoti_torch/generated/*
 torch/csrc/jit/generated/*
 torch/csrc/jit/fuser/config.h
 torch/csrc/nn/THCUNN.cpp

caffe2/CMakeLists.txt (+8)

@@ -368,6 +368,7 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
     "${TORCH_SRC_DIR}/csrc/autograd/generated/TraceType_4.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp"
+    "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp"
   )
   if(BUILD_LAZY_TS_BACKEND)
     list(APPEND GENERATED_CXX_TORCH
@@ -422,12 +423,17 @@ set(GENERATED_TESTING_PYTHON
   "${TORCH_SRC_DIR}/testing/_internal/generated/annotated_fn_args.py"
 )
 
+set(GENERATED_CXX_TORCH_CUDA
+  "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
+)
+
 set(TORCH_GENERATED_CODE
   ${GENERATED_CXX_TORCH}
   ${GENERATED_H_TORCH}
   ${GENERATED_CXX_PYTHON}
   ${GENERATED_H_PYTHON}
   ${GENERATED_TESTING_PYTHON}
+  ${GENERATED_CXX_TORCH_CUDA}
 )
 
 set(GEN_PER_OPERATOR_FLAG)
@@ -970,6 +976,7 @@ endif()
 # Compile exposed libraries.
 if(USE_ROCM)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
+  list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
   hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
   if(USE_FLASH_ATTENTION)
     target_link_libraries(torch_hip PRIVATE __caffe2_oort)
@@ -988,6 +995,7 @@ if(USE_ROCM)
   endif()
 elseif(USE_CUDA)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
+  list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA})
   if(CUDA_SEPARABLE_COMPILATION)
     # Separate compilation fails when kernels using `thrust::sort_by_key`
     # are linked with the rest of CUDA code. Workaround by linking them separately.

setup.py (+1)

@@ -1250,6 +1250,7 @@ def main():
        "include/torch/csrc/inductor/aoti_runtime/*.h",
        "include/torch/csrc/inductor/aoti_torch/*.h",
        "include/torch/csrc/inductor/aoti_torch/c/*.h",
+       "include/torch/csrc/inductor/aoti_torch/generated/*.h",
        "include/torch/csrc/jit/*.h",
        "include/torch/csrc/jit/backends/*.h",
        "include/torch/csrc/jit/generated/*.h",

torch/csrc/inductor/aoti_torch/utils.h (+105)

@@ -1,7 +1,13 @@
 #pragma once
 
+#include <ATen/Tensor.h>
+#include <ATen/core/List.h>
+#include <c10/core/DeviceType.h>
+#include <c10/core/SymIntArrayRef.h>
+#include <c10/util/ArrayRef.h>
 #include <c10/util/Logging.h>
 #include <c10/util/Optional.h>
+#include <c10/util/OptionalArrayRef.h>
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
 
@@ -18,6 +24,8 @@
   return AOTI_TORCH_SUCCESS;
 
 namespace torch::aot_inductor {
+
+// utility functions to convert a pointer to an optional value
 template <class T>
 inline c10::optional<T> pointer_to_optional(T* ptr) {
   return ptr ? c10::make_optional(*ptr) : c10::nullopt;
@@ -34,4 +42,101 @@ inline c10::optional<at::Tensor> pointer_to_optional(AtenTensorHandle* ptr)
              : c10::nullopt;
 }
 
+template <>
+inline c10::optional<at::Tensor> pointer_to_optional(
+    const AtenTensorHandle* ptr) {
+  return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
+             : c10::nullopt;
+}
+
+inline c10::optional<c10::Device> pointer_to_optional_device(
+    int32_t* device_type,
+    int32_t device_index) {
+  return device_type ? c10::make_optional(c10::Device(
+                           static_cast<c10::DeviceType>(*device_type),
+                           static_cast<c10::DeviceIndex>(device_index)))
+                     : c10::nullopt;
+}
+
+// utility functions to convert a pointer to a list
+template <typename T>
+struct is_optional : std::false_type {};
+template <typename T>
+struct is_optional<c10::optional<T>> : std::true_type {};
+
+template <class T>
+inline c10::ArrayRef<T> pointer_to_list(T* ptr, int64_t len) {
+  return c10::ArrayRef<T>(ptr, len);
+}
+
+template <
+    class T,
+    class U,
+    typename = std::enable_if_t<!std::is_same_v<T, U>>,
+    typename = std::enable_if_t<!is_optional<T>::value>>
+inline std::vector<T> pointer_to_list(U* ptr, int64_t len) {
+  // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
+  // site
+  std::vector<T> result;
+  result.reserve(len);
+  for (int64_t i = 0; i < len; i++) {
+    result.emplace_back(T(ptr[i]));
+  }
+  return result;
+}
+
+template <class T, class U, typename = std::enable_if_t<is_optional<T>::value>>
+inline std::vector<T> pointer_to_list(U** ptr, int64_t len) {
+  // Here U** denotes a list of optional arguments
+  // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
+  // site
+  std::vector<T> result;
+  result.reserve(len);
+  for (int64_t i = 0; i < len; i++) {
+    result.emplace_back(pointer_to_optional(ptr[i]));
+  }
+  return result;
+}
+
+template <>
+inline std::vector<at::Tensor> pointer_to_list(
+    const AtenTensorHandle* ptr,
+    int64_t len) {
+  std::vector<at::Tensor> result;
+  result.reserve(len);
+  for (int64_t i = 0; i < len; i++) {
+    result.emplace_back(*tensor_handle_to_tensor_pointer(ptr[i]));
+  }
+  return result;
+}
+
+template <>
+inline std::vector<c10::optional<at::Tensor>> pointer_to_list(
+    const AtenTensorHandle** ptr,
+    int64_t len) {
+  std::vector<c10::optional<at::Tensor>> result;
+  result.reserve(len);
+  for (int64_t i = 0; i < len; i++) {
+    result.emplace_back(pointer_to_optional<at::Tensor>(ptr[i]));
+  }
+  return result;
+}
+
+template <int N>
+inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
+  std::array<bool, N> result;
+  std::copy(ptr, ptr + N, result.begin());
+  return result;
+}
+
+// utility functions to convert a pointer to a list of optional values
+template <class T, class U>
+inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
+    U** ptr,
+    int64_t len) {
+  return ptr
+      ? c10::make_optional<c10::ArrayRef<T>>(pointer_to_list<T>(*ptr, len))
+      : c10::nullopt;
+}
+
 } // namespace torch::aot_inductor
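
As a usage note, the sketch below shows how the new helpers translate the C-side calling convention back into c10 types, assuming it is compiled inside the PyTorch source tree where this header is available. The demo function and its values are illustrative and not part of the commit.

```cpp
// Illustrative usage of the conversion helpers in utils.h (not commit code):
// nullable pointers become c10::optional, and pointer+length pairs become
// c10::ArrayRef views.
#include <cstdint>
#include <torch/csrc/inductor/aoti_torch/utils.h>

namespace shim_helper_demo {

using namespace torch::aot_inductor;

inline void demo() {
  // Optional parameter: a null pointer means c10::nullopt.
  double alpha = 2.0;
  c10::optional<double> maybe_alpha = pointer_to_optional(&alpha);
  c10::optional<double> no_alpha = pointer_to_optional<double>(nullptr);

  // List parameter: pointer + length becomes a non-owning ArrayRef view.
  int64_t sizes[] = {2, 3, 4};
  c10::ArrayRef<int64_t> sizes_ref = pointer_to_list(sizes, 3);

  // Optional list parameter: pointer of pointer, null outer pointer == nullopt.
  int64_t* sizes_ptr = sizes;
  c10::optional<c10::ArrayRef<int64_t>> maybe_sizes =
      pointer_to_optional_list<int64_t>(&sizes_ptr, 3);

  // Optional device: device_type pointer plus a device index; the integer 1 is
  // assumed here to correspond to c10::DeviceType::CUDA.
  int32_t cuda_type = 1;
  c10::optional<c10::Device> dev = pointer_to_optional_device(&cuda_type, 0);

  (void)maybe_alpha; (void)no_alpha; (void)sizes_ref; (void)maybe_sizes; (void)dev;
}

} // namespace shim_helper_demo
```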

torchgen/gen.py (+65, -8)

@@ -44,6 +44,11 @@
     with_native_function,
     with_native_function_and_indices,
 )
+from torchgen.gen_aoti_c_shim import (
+    gen_aoti_c_shim,
+    gen_static_dispatch_backend_call_signature,
+    get_backend_index_for_aoti,
+)
 from torchgen.gen_functionalization_type import (
     gen_functionalization_definition,
     gen_functionalization_registration,
@@ -416,14 +421,7 @@ def generate_static_dispatch_backend_call(
     f: NativeFunction,
     backend_index: BackendIndex,
 ) -> str:
-    cpp_sigs = CppSignatureGroup.from_native_function(
-        f, method=False, fallback_binding=False
-    )
-    if sig.symint and f.func.has_symint():
-        cpp_sig = cpp_sigs.symint_signature
-    else:
-        cpp_sig = cpp_sigs.signature
-    assert cpp_sig is not None
+    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
     name = cpp_sig.name()
     exprs = translate_args(sig, cpp_sig)
     backend_metadata = backend_index.get_kernel(f)
@@ -2181,6 +2179,7 @@ def gen_source_files(
     selector: SelectiveBuilder,
     static_dispatch_idx: List[BackendIndex],
     backend_indices: Dict[DispatchKey, BackendIndex],
+    aoti_fm: FileManager,
     core_fm: FileManager,
     cpu_fm: FileManager,
     cpu_vec_fm: FileManager,
@@ -2350,6 +2349,60 @@ def operator_headers() -> List[str]:
        else:
            raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
 
+        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
+
+            def get_header(
+                f: NativeFunction,
+            ) -> Optional[str]:
+                backend_index = get_backend_index_for_aoti(
+                    f, dispatch_key, backend_indices
+                )
+                return (
+                    None
+                    if backend_index is None
+                    else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
+                )
+
+            def headers_for_aoti() -> str:
+                headers = []
+                for g in grouped_native_functions:
+                    if isinstance(g, NativeFunctionsGroup):
+                        for f in g.functions():
+                            # some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
+                            header = get_header(f)
+                            if header is not None:
+                                headers.append(header)
+                    else:
+                        header = get_header(g)
+                        if header is not None:
+                            headers.append(header)
+                return "\n".join(sorted(set(headers)))
+
+            extra_headers = (
+                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
+            )
+
+            aoti_fm.write(
+                f"c_shim_{dispatch_key.lower()}.h",
+                lambda: gen_aoti_c_shim(
+                    native_functions,
+                    dispatch_key,
+                    backend_indices,
+                    header=True,
+                    includes="",
+                ),
+            )
+            aoti_fm.write(
+                f"c_shim_{dispatch_key.lower()}.cpp",
+                lambda: gen_aoti_c_shim(
+                    native_functions,
+                    dispatch_key,
+                    backend_indices,
+                    header=False,
+                    includes=headers_for_aoti() + "\n" + extra_headers,
+                ),
+            )
+
        del fm
 
    # BackendSelect is generated specially
@@ -2783,6 +2836,9 @@ def main() -> None:
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
+   aoti_fm = make_file_manager(
+       options=options, install_dir="torch/csrc/inductor/aoti_torch/generated"
+   )
 
    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
@@ -2825,6 +2881,7 @@ def main() -> None:
        selector=selector,
        static_dispatch_idx=static_dispatch_idx,
        backend_indices=backend_indices,
+       aoti_fm=aoti_fm,
        core_fm=core_fm,
        cpu_fm=cpu_fm,
        cpu_vec_fm=cpu_vec_fm,
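
For orientation, the two aoti_fm.write() calls above emit a header/source pair per backend: c_shim_{backend}.h (header=True) holds declarations, and c_shim_{backend}.cpp (header=False) holds definitions whose #includes come from headers_for_aoti(). The sketch below mimics that shape for a single hypothetical op; the op, the stand-in backend kernel, and the try/catch error handling are assumptions made so the snippet is self-contained, while the real generated files (see the gist linked in the summary) use the shim's own utilities and the per-operator ATen dispatch headers instead.

```cpp
// Hypothetical sketch of the generated header/source shape; every name below
// is a local stand-in declared here so the snippet compiles on its own.
#include <cstdint>
#include <exception>

struct AtenTensorOpaque;
using AtenTensorHandle = AtenTensorOpaque*;
using AOTITorchError = int32_t;
constexpr AOTITorchError AOTI_TORCH_SUCCESS = 0;
constexpr AOTITorchError AOTI_TORCH_FAILURE = 1;

// Stand-in for the backend kernel a real wrapper dispatches to, normally pulled
// in via an "#include <ATen/ops/<op>_<backend>_dispatch.h>" line collected by
// headers_for_aoti().
inline void backend_fill_kernel(AtenTensorHandle /*self*/, double /*value*/) {}

// What c_shim_cpu.h would carry (header=True): the declaration only.
extern "C" AOTITorchError aoti_torch_cpu_fill_example(
    AtenTensorHandle self,
    double value);

// What c_shim_cpu.cpp would carry (header=False): the definition, which
// converts C ABI arguments, calls the backend kernel, and turns any C++
// exception into a plain error code (the real shim uses its own macro for this).
extern "C" AOTITorchError aoti_torch_cpu_fill_example(
    AtenTensorHandle self,
    double value) {
  try {
    backend_fill_kernel(self, value);  // real code converts handle -> at::Tensor first
    return AOTI_TORCH_SUCCESS;
  } catch (const std::exception&) {
    return AOTI_TORCH_FAILURE;
  }
}
```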
