Skip to content

Commit bc77f24

Browse files
authored
Zmz/matmul12 (DeepLink-org#869)
* support matmul
1 parent 03056c2 commit bc77f24

File tree

5 files changed

+50
-2
lines changed

5 files changed

+50
-2
lines changed

impl/ascend/convert_config.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@
283283
- diopiHardswishBackward:
284284
dtype: (float64)->float32
285285

286+
- diopiMatmul:
287+
dtype: (float64)->float32
288+
286289
- diopiAtan:
287290
dtype: (uint8, int8, int32, int16, int64, bool)->float32
288291

impl/ascend/device_configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@
431431
args=[
432432
{
433433
"ins": ['input'],
434-
"shape": [Skip((128, 49, 128)),Skip((5,)),Skip((128, 4, 49, 32)),Skip((2, 1, 3136, 3136)),Skip((2, 784, 64)),Skip((2, 16, 8, 64)),Skip((2, 31, 6, 40, 512)),],
434+
"shape": [Skip((2, 31, 6, 40, 512)),],
435435
},
436436
]
437437
),

impl/ascend_npu/ascend_config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ ascend:
200200
- diopiScatterScalar
201201
- diopiScatterInpScalar
202202
ascend_npu:
203+
- diopiMatmul
203204
- diopiCastDtype
204205
- diopiCopyInp
205206
- diopiCat

impl/ascend_npu/diopi_impl/matmul.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/**
2+
* @file
3+
* @author DeepLink
4+
* @copyright (c) 2023, DeepLink.
5+
*/
6+
7+
#include "helper.hpp"
8+
#include "op_plugin/AclOpsInterface.h"
9+
10+
namespace OP_IMPL_NS {
11+
12+
diopiError_t diopiMatmul(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
13+
BEGIN_CALL_ACL_OP(input, out, other);
14+
acl_op::matmul_out(inputAt, otherAt, outAt);
15+
END_CALL_ACL_OP();
16+
}
17+
18+
} // namespace OP_IMPL_NS

impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp

+27-1
Original file line numberDiff line numberDiff line change
@@ -2895,7 +2895,24 @@ at::Tensor viewStorage(const at::Tensor input, const c10::IntArrayRef sizes, con
28952895
if (st != -1) st *= sizes[i - 1];
28962896
}
28972897
}
2898-
return fromPreAllocated(input.data_ptr() + storageOffset * input.itemsize(), sizes, stridesVec, input.options());
2898+
2899+
// when shape[0]=-1, fill data
2900+
std::vector<int64_t> sizeVec(sizes.size(), 1);
2901+
std::copy(sizes.begin(), sizes.end(), sizeVec.begin());
2902+
if (!sizes.empty() && sizes[0] == -1) {
2903+
bool flag = true;
2904+
for (auto i : sizes) {
2905+
if (!flag && i < 0) {
2906+
TORCH_CHECK(false, "more than one -1, sizes=", sizes);
2907+
}
2908+
if (i < 0) {
2909+
flag = false;
2910+
}
2911+
}
2912+
int count = std::accumulate(sizeVec.begin() + 1, sizeVec.end(), 1, std::multiplies<int>());
2913+
sizeVec[0] = input.numel() / count;
2914+
}
2915+
return fromPreAllocated(input.data_ptr() + storageOffset * input.itemsize(), sizeVec, stridesVec, input.options());
28992916
}
29002917

29012918
c10::List<c10::optional<at::Tensor>> castIntIndicesToLongIndices(const c10::List<c10::optional<at::Tensor>>& indices) {
@@ -3057,7 +3074,11 @@ at::Tensor wrapper__transpose(const at::Tensor& self, int64_t dim0, int64_t dim1
30573074
}
30583075

30593076
at::Scalar wrapper___local_scalar_dense(const at::Tensor& self) { return at_npu::native::NPUNativeFunctions::_local_scalar_dense(self); }
3077+
at::Tensor& wrapper_out_mm_out(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::mm_out(self, mat2, out); }
30603078

3079+
at::Tensor& wrapper_source_Tensor_set_(at::Tensor& self, const at::Tensor& source) { return at_npu::native::NPUNativeFunctions::set_(self, source); }
3080+
at::Tensor& wrapper_out_bmm_out(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::bmm_out(self, mat2, out); }
3081+
at::Tensor wrapper__dot(const at::Tensor& self, const at::Tensor& tensor) { return acl_op::dot(self, tensor); }
30613082
} // namespace
30623083

30633084
namespace at {
@@ -3092,6 +3113,11 @@ TORCH_LIBRARY_IMPL(aten, XLA, m) {
30923113
m.impl("repeat", TORCH_FN(wrapper__repeat));
30933114
m.impl("transpose.int", TORCH_FN(wrapper__transpose));
30943115
m.impl("_local_scalar_dense", TORCH_FN(wrapper___local_scalar_dense));
3116+
m.impl("cat", TORCH_FN(wrapper__cat));
3117+
m.impl("mm.out", TORCH_FN(wrapper_out_mm_out));
3118+
m.impl("set_.source_Tensor", TORCH_FN(wrapper_source_Tensor_set_));
3119+
m.impl("dot", TORCH_FN(wrapper__dot));
3120+
m.impl("bmm.out", TORCH_FN(wrapper_out_bmm_out));
30953121
};
30963122

30973123
TORCH_LIBRARY_IMPL(_, XLA, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&ascend_diopi_fallback>()); }

0 commit comments

Comments
 (0)