Skip to content

Commit acc612d

Browse files
[mps] Add cshim for torchao mps ops (#2502)
Co-authored-by: Manuel Candales <[email protected]>
1 parent e4f74be commit acc612d

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
3+
# List of ops and their c-shim declarations used for AOTInductor
4+
# Check out TestUIntxWeightOnlyLinearQuantizer.test_export_accuracy on how to use it
5+
torchao_op_c_shim: dict[torch.ops.OpOverload, list[str]] = {}
6+
7+
for nbit in range(1, 8):
8+
op_name = f"_linear_fp_act_{nbit}bit_weight"
9+
torchao_op_c_shim[getattr(torch.ops.torchao, op_name).default] = [
10+
f"AOTITorchError aoti_torch_mps_{op_name}(AtenTensorHandle A, AtenTensorHandle B, int64_t group_size, AtenTensorHandle S, AtenTensorHandle Z, AtenTensorHandle* ret)",
11+
]

torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// LICENSE file in the root directory of this source tree.
66

77
// clang-format off
8+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
9+
#include <torch/csrc/inductor/aoti_torch/utils.h>
810
#include <torch/library.h>
911
#include <ATen/native/mps/OperationUtils.h>
1012
#include <torchao/experimental/kernels/mps/src/lowbit.h>
@@ -239,3 +241,44 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
239241
}
240242

241243
} // namespace torchao::kernels::mps::lowbit::aten
244+
245+
246+
// c-shim wrappers for AOTInductor
247+
// Check out TestUIntxWeightOnlyLinearQuantizer.test_export_accuracy on how to use it
248+
#define DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(BITS) \
249+
extern "C" { \
250+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__linear_fp_act_##BITS##bit_weight( \
251+
AtenTensorHandle A, \
252+
AtenTensorHandle B, \
253+
int64_t group_size, \
254+
AtenTensorHandle S, \
255+
AtenTensorHandle Z, \
256+
AtenTensorHandle* ret) { \
257+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ \
258+
auto op_handle = \
259+
c10::Dispatcher::singleton() \
260+
.findSchemaOrThrow("torchao::_linear_fp_act_" #BITS "bit_weight", "") \
261+
.typed<at::Tensor( \
262+
const at::Tensor& A, \
263+
const at::Tensor& B, \
264+
int64_t group_size, \
265+
const at::Tensor& S, \
266+
const at::Tensor& Z)>(); \
267+
auto tmp_result = op_handle.call( \
268+
torch::aot_inductor::resolve_tensor_dispatch_flags(A), \
269+
torch::aot_inductor::resolve_tensor_dispatch_flags(B), \
270+
group_size, \
271+
torch::aot_inductor::resolve_tensor_dispatch_flags(S), \
272+
torch::aot_inductor::resolve_tensor_dispatch_flags(Z)); \
273+
*ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result)); \
274+
}); \
275+
} \
276+
}
277+
278+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(1)
279+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(2)
280+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(3)
281+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(4)
282+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(5)
283+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(6)
284+
DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(7)

0 commit comments

Comments (0)