|
5 | 5 | // LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | // clang-format off |
| 8 | +#include <torch/csrc/inductor/aoti_torch/c/shim.h> |
| 9 | +#include <torch/csrc/inductor/aoti_torch/utils.h> |
8 | 10 | #include <torch/library.h> |
9 | 11 | #include <ATen/native/mps/OperationUtils.h> |
10 | 12 | #include <torchao/experimental/kernels/mps/src/lowbit.h> |
@@ -239,3 +241,44 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { |
239 | 241 | } |
240 | 242 |
|
241 | 243 | } // namespace torchao::kernels::mps::lowbit::aten |
| 244 | + |
| 245 | + |
| 246 | +// C shim wrappers for AOTInductor. For a given BITS, the macro below defines an extern "C" function aoti_torch_mps__linear_fp_act_<BITS>bit_weight that finds the torchao::_linear_fp_act_<BITS>bit_weight op via c10::Dispatcher, invokes it with (A, B, group_size, S, Z), and writes a handle for the resulting tensor to *ret. |
| 247 | +// Despite the DECLARE_ prefix, each expansion is a full function definition. The body runs inside AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE (presumably mapping thrown exceptions to the returned AOTITorchError; confirm in aoti_torch/utils.h). See TestUIntxWeightOnlyLinearQuantizer.test_export_accuracy for a usage example. |
| 248 | +#define DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(BITS) \ |
| 249 | +extern "C" { \ |
| 250 | + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__linear_fp_act_##BITS##bit_weight( \ |
| 251 | + AtenTensorHandle A, \ |
| 252 | + AtenTensorHandle B, \ |
| 253 | + int64_t group_size, \ |
| 254 | + AtenTensorHandle S, \ |
| 255 | + AtenTensorHandle Z, \ |
| 256 | + AtenTensorHandle* ret) { \ |
| 257 | + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ \ |
| 258 | + auto op_handle = /* schema lookup is repeated on every call; the handle is not cached */ \ |
| 259 | + c10::Dispatcher::singleton() \ |
| 260 | + .findSchemaOrThrow("torchao::_linear_fp_act_" #BITS "bit_weight", "") \ |
| 261 | + .typed<at::Tensor( \ |
| 262 | + const at::Tensor& A, \ |
| 263 | + const at::Tensor& B, \ |
| 264 | + int64_t group_size, \ |
| 265 | + const at::Tensor& S, \ |
| 266 | + const at::Tensor& Z)>(); \ |
| 267 | + auto tmp_result = op_handle.call( \ |
| 268 | + torch::aot_inductor::resolve_tensor_dispatch_flags(A), \ |
| 269 | + torch::aot_inductor::resolve_tensor_dispatch_flags(B), \ |
| 270 | + group_size, \ |
| 271 | + torch::aot_inductor::resolve_tensor_dispatch_flags(S), \ |
| 272 | + torch::aot_inductor::resolve_tensor_dispatch_flags(Z)); \ |
| 273 | + *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result)); \ |
| 274 | + }); \ |
| 275 | + } \ |
| 276 | +} |
| 277 | + |
| 278 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(1) /* Emit one extern "C" shim per weight bit width, 1 through 7. */ |
| 279 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(2) |
| 280 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(3) |
| 281 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(4) |
| 282 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(5) |
| 283 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(6) |
| 284 | +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(7) |
0 commit comments