Skip to content

Commit 178ce14

Browse files
int3pytorchmergebot
authored andcommitted
Hoist out auxiliary values in optional-typed arguments (pytorch#123613)
This fixes pytorch#123176, and partially addresses pytorch#121814 too. pytorch#123176 uses an optional device arg while pytorch#121814 uses an optional list arg. For optional arguments that have auxiliary info -- specifically, tuples / lists with their length parameter, and device types with their device index -- we need to hoist out the extra argument. E.g. when passing a device with ID 1, we want to emit ``` auto var_0 = cached_torch_device_type_cpu; aoti_torch_foo(..., &var_0, 1); ``` instead of the (syntactically incorrect) ``` auto var_0 = cached_torch_device_type_cpu,1; aoti_torch_foo(..., &var_0); ``` Pull Request resolved: pytorch#123613 Approved by: https://github.com/desertfire
1 parent 1970a80 commit 178ce14

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

test/inductor/test_cpu_cpp_wrapper.py

-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase):
8888
"test_qlinear_cpu",
8989
"test_qlinear_dequant_promotion_cpu",
9090
"test_qlinear_relu_cpu",
91-
"test_randn_with_dtype_and_device_cpu",
9291
"test_scatter5_cpu",
9392
"test_scatter6_cpu",
9493
"test_tensor2_cpu",

torch/_inductor/codegen/cpp_wrapper_cpu.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,7 @@ def generate_inf_and_nan_checker(self, nodes):
14031403
def codegen_device(self, device):
14041404
if config.abi_compatible:
14051405
self.used_cached_devices.add(device.type)
1406-
return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}"
1406+
return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}"
14071407
else:
14081408
from .cpp import DEVICE_TO_ATEN
14091409

@@ -2093,8 +2093,22 @@ def val_to_cpp_arg_str(self, type_, val) -> str:
20932093
return "0" # nullptr is not available in C
20942094
if not isinstance(type_.getElementType(), torch.TensorType):
20952095
var_name = f"var_{next(self.arg_var_id)}"
2096-
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
2097-
return f"&{var_name}"
2096+
if isinstance(
2097+
type_.getElementType(),
2098+
(torch.ListType, torch.TupleType, torch.DeviceObjType),
2099+
):
2100+
arg_str = self.val_to_arg_str(val)
2101+
if val is None:
2102+
return "{arg_str}, 0"
2103+
else:
2104+
# For datatypes with auxiliary info, we need to hoist out the extra arguments.
2105+
# NOTE: This only works if there is one additional argument, though it can easily be generalized.
2106+
main_value, aux = arg_str.rsplit(", ")
2107+
self.writeline(f"auto {var_name} = {main_value};")
2108+
return f"&{var_name}, {aux}"
2109+
else:
2110+
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
2111+
return f"&{var_name}"
20982112
elif config.c_shim_version == "2":
20992113
# Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
21002114
base_handle = self.val_to_arg_str(val)

torch/csrc/inductor/aoti_torch/utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
141141
return result;
142142
}
143143

144-
// utility functions to convert a pointer to a list of optional values
144+
// Utility function to convert a pointer to an optional list of values
145145
template <class T, class U>
146146
inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
147147
U** ptr,

0 commit comments

Comments
 (0)