@@ -4,7 +4,6 @@
 #include <ATen/core/Tensor.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -47,18 +46,6 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
   TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
-}
-
 } // namespace native
 } // namespace at
 
@@ -460,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
       TORCH_FN(mkldnn_linear_pointwise_binary));
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-      " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-  // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
-  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
-
-  // Validation checks have passed lets resize the output to actual size
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto src = at::native::itensor_view_from_dense(mat1_c);
-  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
-  bool with_bias = bias.has_value();
-  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
-          N = mat2_sizes[1];
-
-  std::vector<int64_t> src_dims = {M, K};
-  std::vector<int64_t> weight_dims = {K, N};
-  std::vector<int64_t> dst_dims = {M, N};
-
-  ideep::tensor dst = at::native::itensor_view_from_dense(out);
-  auto src_desc = ideep::tensor::desc(
-      src_dims,
-      get_mkldnn_dtype(mat1.scalar_type()),
-      ideep::format_tag::any);
-  auto weights_desc = ideep::tensor::desc(
-      weight_dims,
-      get_mkldnn_dtype(mat2.scalar_type()),
-      ideep::format_tag::any);
-  auto dst_desc = ideep::tensor::desc(
-      dst_dims,
-      get_mkldnn_dtype(out.scalar_type()),
-      ideep::format_tag::any);
-  ideep::tensor onednn_bias;
-  if (with_bias) {
-    auto bias_value = bias.value();
-    if (bias_value.dim() == 1) {
-      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
-      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
-    } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias_value);
-    }
-  }
-  auto bias_desc = ideep::tensor::desc();
-  if (with_bias) {
-    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
-        get_mkldnn_dtype(bias.value().scalar_type()),
-        ideep::format_tag::any);
-  }
-  auto op_attr = ideep::attr_t();
-  if (input_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
-  }
-  if (weight_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
-  }
-
-  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
-  auto engine = ideep::engine::cpu_engine();
-  dnnl::matmul::primitive_desc primitive_desc = with_bias
-      ? dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
-      : dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, dst_desc, op_attr);
-  auto primitive = dnnl::matmul(primitive_desc);
-
-  // Prepare args and execute primitive
-  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
-  ideep::exec_args args;
-  args.insert({DNNL_ARG_SRC, src});
-  args.insert({DNNL_ARG_WEIGHTS, weight_t});
-  args.insert({DNNL_ARG_DST, dst});
-  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
-  if (with_bias) {
-    args.insert({DNNL_ARG_BIAS, onednn_bias});
-  }
-  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
-  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
-
-  if (input_scale != 1.0f) {
-    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
-  }
-  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
-
-  primitive.execute(ideep::stream::default_stream(), args);
-  return out;
-}
-
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED