
Commit fca457b

Revert "Add torch._scaled_mm for CPU (pytorch#139975)"
This reverts commit 3f80632. Reverted pytorch#139975 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing some tests in trunk ([comment](pytorch#139975 (comment)))
1 parent 0f474a9 commit fca457b
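
For context, the reverted PR added CPU kernels for the existing `torch._scaled_mm` operator (FP8 matrix multiply with per-tensor scales). Below is a minimal usage sketch of what the CPU path accepted, based on the removed code in this diff; the shapes, scale values, and `out_dtype` are illustrative, and after this revert the call dispatches only on CUDA:

import torch

# FP8 inputs with per-tensor scales (illustrative shapes and values)
a = torch.randn(16, 32).to(torch.float8_e4m3fn)   # mat1: M x K
b = torch.randn(32, 8).to(torch.float8_e4m3fn)    # mat2: K x N
scale_a = torch.tensor(1.0)                       # per-tensor scale for mat1
scale_b = torch.tensor(1.0)                       # per-tensor scale for mat2

# Signature as declared in native_functions.yaml in this diff
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)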

File tree

11 files changed: +573 -905 lines

aten/src/ATen/native/Blas.cpp (-83 lines)

@@ -7,11 +7,6 @@
 #include <ATen/Config.h>
 
 #include <ATen/native/mkldnn/Matmul.h>
-#include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
-#if !defined(__s390x__) && !defined(__powerpc__)
-#include <cpuinfo.h>
-#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/CPUFunctions.h>
@@ -29,9 +24,6 @@
 #include <ATen/ops/mv_native.h>
 #include <ATen/ops/scalar_tensor_native.h>
 #include <ATen/ops/vdot_native.h>
-#include <ATen/ops/_scaled_mm_native.h>
-#include <ATen/ops/mul.h>
-#include <ATen/ops/matmul.h>
 #endif
 
 namespace at::meta {
@@ -230,79 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){
 
 }
 
-static Tensor&
-_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-       " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
-  auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
-  auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
-  if (bias) {
-    out_tmp.add_(bias.value());
-  }
-  out_tmp = out_tmp.to(out.scalar_type());
-  out.copy_(out_tmp);
-  return out;
-}
-
-Tensor&
-_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-#if AT_MKLDNN_ENABLED()
-  if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
-    return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  } else
-#endif
-  {
-    return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  }
-}
-
-Tensor
-_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum) {
-  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
-  return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-}
-
 } // namespace at::native
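
The removed `_scaled_mm_out_cpu_emulated` fallback above is essentially dequantize-then-matmul in FP32. A rough Python equivalent of that math, not the actual kernel (the function name and defaults here are illustrative):

import torch

def scaled_mm_emulated(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=torch.float32):
    # Dequantize the FP8 operands to FP32 using per-tensor scales
    fp32_mat1 = mat1.to(torch.float32) * scale_a.item()
    fp32_mat2 = mat2.to(torch.float32) * scale_b.item()
    out = fp32_mat1 @ fp32_mat2        # plain FP32 matmul
    if bias is not None:
        out = out + bias               # optional bias, broadcast over rows
    return out.to(out_dtype)           # cast to the requested output dtype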

aten/src/ATen/native/mkldnn/Linear.cpp (-126 lines)

@@ -4,7 +4,6 @@
 #include <ATen/core/Tensor.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -47,18 +46,6 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
   TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
-}
-
 } // namespace native
 } // namespace at
 
@@ -460,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
       TORCH_FN(mkldnn_linear_pointwise_binary));
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-       " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-  // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
-  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
-
-  // Validation checks have passed lets resize the output to actual size
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto src = at::native::itensor_view_from_dense(mat1_c);
-  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
-  bool with_bias = bias.has_value();
-  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
-          N = mat2_sizes[1];
-
-  std::vector<int64_t> src_dims = {M, K};
-  std::vector<int64_t> weight_dims = {K, N};
-  std::vector<int64_t> dst_dims = {M, N};
-
-  ideep::tensor dst = at::native::itensor_view_from_dense(out);
-  auto src_desc = ideep::tensor::desc(
-      src_dims,
-      get_mkldnn_dtype(mat1.scalar_type()),
-      ideep::format_tag::any);
-  auto weights_desc = ideep::tensor::desc(
-      weight_dims,
-      get_mkldnn_dtype(mat2.scalar_type()),
-      ideep::format_tag::any);
-  auto dst_desc = ideep::tensor::desc(
-      dst_dims,
-      get_mkldnn_dtype(out.scalar_type()),
-      ideep::format_tag::any);
-  ideep::tensor onednn_bias;
-  if (with_bias) {
-    auto bias_value = bias.value();
-    if (bias_value.dim() == 1) {
-      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
-      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
-    } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias_value);
-    }
-  }
-  auto bias_desc = ideep::tensor::desc();
-  if (with_bias) {
-    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
-        get_mkldnn_dtype(bias.value().scalar_type()),
-        ideep::format_tag::any);
-  }
-  auto op_attr = ideep::attr_t();
-  if (input_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
-  }
-  if (weight_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
-  }
-
-  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
-  auto engine = ideep::engine::cpu_engine();
-  dnnl::matmul::primitive_desc primitive_desc = with_bias
-      ? dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
-      : dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, dst_desc, op_attr);
-  auto primitive = dnnl::matmul(primitive_desc);
-
-  // Prepare args and execute primitive
-  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
-  ideep::exec_args args;
-  args.insert({DNNL_ARG_SRC, src});
-  args.insert({DNNL_ARG_WEIGHTS, weight_t});
-  args.insert({DNNL_ARG_DST, dst});
-  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
-  if (with_bias) {
-    args.insert({DNNL_ARG_BIAS, onednn_bias});
-  }
-  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
-  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
-
-  if (input_scale != 1.0f) {
-    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
-  }
-  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
-
-  primitive.execute(ideep::stream::default_stream(), args);
-  return out;
-}
-
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED

aten/src/ATen/native/mkldnn/Linear.h (-12 lines)

@@ -35,15 +35,3 @@ C10_API Tensor mkl_linear(
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED()
-
-namespace at::native {
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out);
-} // namespace at::native

aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp (+1, -21 lines)

@@ -57,10 +57,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
       return ideep::tensor::data_type::bf16;
     case ScalarType::Half:
       return ideep::tensor::data_type::f16;
-    case ScalarType::Float8_e4m3fn:
-      return ideep::tensor::data_type::f8_e4m3;
-    case ScalarType::Float8_e5m2:
-      return ideep::tensor::data_type::f8_e5m2;
     default:
       TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
   }
@@ -165,24 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
                 const_cast<void*>(tensor.const_data_ptr()) :
                 tensor.data_ptr()};
   }
-  else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
-    return {{tensor.sizes().vec(),
-             ideep::tensor::data_type::f8_e4m3,
-             tensor.strides().vec()},
-            from_const_data_ptr ?
-                const_cast<void*>(tensor.const_data_ptr()) :
-                tensor.data_ptr()};
-  }
-  else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
-    return {{tensor.sizes().vec(),
-             ideep::tensor::data_type::f8_e5m2,
-             tensor.strides().vec()},
-            from_const_data_ptr ?
-                const_cast<void*>(tensor.const_data_ptr()) :
-                tensor.data_ptr()};
-  }
   else {
-    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
+    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
   }
 }
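
The mapping removed above is what connected PyTorch's FP8 scalar types to oneDNN's f8_e4m3/f8_e5m2 data types. For reference, the PyTorch dtypes themselves remain available and can be produced by a plain cast; a minimal sketch:

import torch

x = torch.randn(4, 4)
x_e4m3 = x.to(torch.float8_e4m3fn)   # was mapped to ideep data_type::f8_e4m3
x_e5m2 = x.to(torch.float8_e5m2)     # was mapped to ideep data_type::f8_e5m2
print(x_e4m3.dtype, x_e5m2.dtype)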

aten/src/ATen/native/native_functions.yaml (-2 lines)

@@ -7071,13 +7071,11 @@
 - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
   variants: function
   dispatch:
-    CPU: _scaled_mm_cpu
     CUDA: _scaled_mm_cuda
 
 - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CPU: _scaled_mm_out_cpu
     CUDA: _scaled_mm_out_cuda
 
 # NOTE [ Sparse: autograd and API ]
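
With the CPU dispatch entries removed, only the CUDA kernels remain registered for `_scaled_mm`. A rough way to confirm the post-revert behavior from Python, assuming an FP8-capable build (the exact error text is not guaranteed):

import torch

a = torch.randn(4, 4).to(torch.float8_e4m3fn)
b = torch.randn(4, 4).to(torch.float8_e4m3fn)
try:
    torch._scaled_mm(a, b, torch.tensor(1.0), torch.tensor(1.0))
except (RuntimeError, NotImplementedError) as err:
    print("no CPU kernel registered for _scaled_mm:", err)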
