Skip to content

Commit 889c520

Browse files
authored
tyf/support nondense (DeepLink-org#860)
* support more dense op * support non dense
1 parent cfb0dd5 commit 889c520

File tree

8 files changed

+295
-49
lines changed

8 files changed

+295
-49
lines changed

adaptor/codegen/gen.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
str_to_diopi_format = {
3535
"NCHW": "diopiMemoryFormat_t::Contiguous",
3636
"NCL": "diopiMemoryFormat_t::Contiguous",
37+
"NCDHW": "diopiMemoryFormat_t::Contiguous",
3738
"NLC": "diopiMemoryFormat_t::ChannelsLast1d",
3839
"NHWC": "diopiMemoryFormat_t::ChannelsLast",
3940
"NDHWC": "diopiMemoryFormat_t::ChannelsLast3d",
41+
"UD": "diopiMemoryFormat_t::Preserve",
4042
}
4143

4244

@@ -441,6 +443,7 @@ def analysis_configs(config: List[dict], funcs_info: dict) -> dict:
441443
or layout == "NCL"
442444
or layout == "NDHWC"
443445
or layout == "NCDHW"
446+
or layout == "UD"
444447
):
445448
op_layouts.append(layout)
446449
else:
@@ -821,4 +824,4 @@ def gen_all_codes() -> None:
821824

822825

823826
if __name__ == "__main__":
824-
gen_all_codes()
827+
gen_all_codes()

adaptor/csrc/convert.cpp

+33-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,30 @@
11
#include "convert.hpp"
22

3+
bool denseCheckAdaptor(diopiSize_t shape, diopiSize_t stride) {
4+
int dim = shape.len;
5+
std::vector<std::pair<int64_t, int64_t>> stridesSizes(dim, std::pair<int64_t, int64_t>(1, 1));
6+
7+
for (int i = 0; i < dim; i++) {
8+
stridesSizes[i] = std::pair<int64_t, int64_t>(stride.data[i], shape.data[i]);
9+
10+
if (stride.data[i] == 0 || shape.data[i] == 0) {
11+
return false;
12+
}
13+
}
14+
15+
sort(stridesSizes.begin(), stridesSizes.end(), [](std::pair<int64_t, int64_t> a, std::pair<int64_t, int64_t> b) { return a.first < b.first; });
16+
// e.g. shape = 2,3,4,5,stride = 1,2,6,24 pass
17+
// e.g. shape = 2,3,4,5, stride = 1,2,6,12 should not pass
18+
int cur = 1;
19+
for (int i = 0; i < dim; i++) {
20+
if (stridesSizes[i].first != cur) {
21+
return false;
22+
}
23+
cur *= stridesSizes[i].second;
24+
}
25+
return true;
26+
}
27+
328
std::vector<int64_t> calcStrides(diopiSize_t size, diopiMemoryFormat_t format) {
429
size_t ndims = size.len;
530
std::vector<int64_t> strides(ndims);
@@ -104,7 +129,7 @@ bool isContiguous(diopiSize_t size, diopiSize_t strideDiopi, diopiMemoryFormat_t
104129

105130
if (format == diopiMemoryFormat_t::Contiguous) {
106131
for (int64_t i = dim - 1; i >= 0; i--) {
107-
const auto &shapeD = shape[i];
132+
const auto& shapeD = shape[i];
108133
if (shapeD == 0) {
109134
return true;
110135
}
@@ -117,8 +142,8 @@ bool isContiguous(diopiSize_t size, diopiSize_t strideDiopi, diopiMemoryFormat_t
117142
}
118143
} else if (format == diopiMemoryFormat_t::ChannelsLast) {
119144
if (dim != 4) return false;
120-
for (auto &i : {1, 3, 2, 0}) {
121-
const auto &shapeD = shape[i];
145+
for (auto& i : {1, 3, 2, 0}) {
146+
const auto& shapeD = shape[i];
122147
if (shapeD != 1) {
123148
// shape_d != 1 help dealing with shape like [2, 2048, 1, 1]
124149
if (strides[i] != stride) {
@@ -129,8 +154,8 @@ bool isContiguous(diopiSize_t size, diopiSize_t strideDiopi, diopiMemoryFormat_t
129154
}
130155
} else if (format == diopiMemoryFormat_t::ChannelsLast3d) {
131156
if (dim != 5) return false;
132-
for (auto &i : {1, 4, 3, 2, 0}) {
133-
const auto &shapeD = shape[i];
157+
for (auto& i : {1, 4, 3, 2, 0}) {
158+
const auto& shapeD = shape[i];
134159
if (shapeD != 1) {
135160
if (strides[i] != stride) {
136161
return false;
@@ -140,8 +165,8 @@ bool isContiguous(diopiSize_t size, diopiSize_t strideDiopi, diopiMemoryFormat_t
140165
}
141166
} else if (format == diopiMemoryFormat_t::ChannelsLast1d) {
142167
if (dim != 3) return false;
143-
for (auto &i : {1, 2, 0}) {
144-
const auto &shapeD = shape[i];
168+
for (auto& i : {1, 2, 0}) {
169+
const auto& shapeD = shape[i];
145170
if (shapeD != 1) {
146171
if (strides[i] != stride) {
147172
return false;
@@ -186,7 +211,7 @@ std::vector<diopiMemoryFormat_t> setIntersection(std::vector<diopiMemoryFormat_t
186211
case err: \
187212
return #err;
188213

189-
const char *getDiopiErrorStr(diopiError_t err) {
214+
const char* getDiopiErrorStr(diopiError_t err) {
190215
switch (err) {
191216
DIOPI_ERROR_TO_STR(diopiErrorOccurred)
192217
DIOPI_ERROR_TO_STR(diopiNotInited)

adaptor/csrc/convert.hpp

+19-29
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <ostream>
1414
#include <vector>
1515

16+
bool denseCheckAdaptor(diopiSize_t shape, diopiSize_t stride);
17+
1618
std::vector<int64_t> calcStrides(diopiSize_t size, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous);
1719

1820
bool isLikeChannelsLast(diopiConstTensorHandle_t tensor, bool checkContiguous, diopiMemoryFormat_t format = diopiMemoryFormat_t::ChannelsLast);
@@ -52,7 +54,7 @@ struct RemoveConst<diopiConstTensorHandle_t> {
5254

5355
class NoCast {
5456
public:
55-
static bool getDstDtype(diopiDtype_t srcDtype, diopiDtype_t &targetDtype) {
57+
static bool getDstDtype(diopiDtype_t srcDtype, diopiDtype_t& targetDtype) {
5658
bool convert = false;
5759
switch (srcDtype) {
5860
default:
@@ -63,7 +65,7 @@ class NoCast {
6365
};
6466

6567
template <class T, class strategy = NoCast>
66-
ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
68+
ConvertType castImpl(diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
6769
ConvertType convertType;
6870
if (!src) {
6971
*dst = src;
@@ -77,13 +79,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
7779
strategy::getDstDtype(srcDtype, dstDtype);
7880
std::vector<diopiMemoryFormat_t> targetMemoryFormats = obtainTargetMemoryFormats(srcSize.len, supportMemoryFormats);
7981
diopiTensorHandle_t memoryFormatedTensor = nullptr;
80-
8182
// convertDtype
8283

8384
diopiDevice_t device;
8485
diopiGetTensorDevice(src, &device);
8586
diopiTensorHandle_t tmp0 = nullptr;
8687
bool needConvertDtype = srcDtype != dstDtype;
88+
8789
if (needConvertDtype) {
8890
diopiRequireTensor(ctx, &tmp0, &srcSize, &srcStride, dstDtype, device);
8991
diopiCastDtype(ctx, tmp0, src);
@@ -108,6 +110,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
108110
}
109111
diopiSize_t dstStride = srcStride;
110112
diopiSize_t dstSize = srcSize;
113+
if (!targetMemoryFormats.empty()) {
114+
if (!denseCheckAdaptor(srcSize, srcStride) && supportMemoryFormats[0] == diopiMemoryFormat_t::Preserve) {
115+
targetMemoryFormats.push_back(diopiMemoryFormat_t::Preserve);
116+
needConvertMemoryFormat = true;
117+
}
118+
}
119+
111120
if (needConvertMemoryFormat) {
112121
diopiContiguous(ctx, &memoryFormatedTensor, tmp0, targetMemoryFormats[0]);
113122
convertType.setMemoryFormatConverted();
@@ -122,7 +131,7 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
122131
}
123132

124133
template <class T, class strategy>
125-
ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
134+
ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
126135
ConvertType convertType;
127136
if (!src) {
128137
*dst = src;
@@ -139,6 +148,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
139148
if (targetMemoryFormats.empty()) {
140149
needConvertMemoryFormat = false;
141150
}
151+
142152
for (auto memoryFormat : targetMemoryFormats) {
143153
if (isContiguous(srcSize, srcStride, memoryFormat)) {
144154
needConvertMemoryFormat = false;
@@ -174,7 +184,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
174184
}
175185

176186
template <typename Adaptor, typename... Args>
177-
void dispatchDiopi(diopiContextHandle_t ctx, Args &&...args) {
187+
void dispatchDiopi(diopiContextHandle_t ctx, Args&&... args) {
178188
auto adaptor = Adaptor();
179189
adaptor(ctx, std::forward<Args>(args)...);
180190
}
@@ -195,10 +205,10 @@ template <class strategy = NoCast>
195205
class DiopiTensorWrapper {
196206
public:
197207
// forbid copy/move constructor/assignment
198-
DiopiTensorWrapper(const DiopiTensorWrapper &) = delete;
199-
DiopiTensorWrapper &operator=(const DiopiTensorWrapper &) = delete;
200-
DiopiTensorWrapper(DiopiTensorWrapper &&) = delete;
201-
DiopiTensorWrapper &operator=(DiopiTensorWrapper &&) = delete;
208+
DiopiTensorWrapper(const DiopiTensorWrapper&) = delete;
209+
DiopiTensorWrapper& operator=(const DiopiTensorWrapper&) = delete;
210+
DiopiTensorWrapper(DiopiTensorWrapper&&) = delete;
211+
DiopiTensorWrapper& operator=(DiopiTensorWrapper&&) = delete;
202212

203213
private:
204214
diopiContextHandle_t ctx_;
@@ -230,26 +240,6 @@ class DiopiTensorWrapper {
230240
if (convertType_.isDtypeConverted()) {
231241
diopiCastDtype(ctx_, payload_, memoryFormatedTensor);
232242
}
233-
234-
// if (convertType_.isDtypeConverted() &&
235-
// !convertType_.isMemoryFormatConverted()) {
236-
// diopiCastDtype(ctx_, payload_, tmp_);
237-
// } else if (!convertType_.isDtypeConverted() &&
238-
// convertType_.isMemoryFormatConverted()) {
239-
// diopiCopyInp(ctx_, tmp_, payload_);
240-
// } else {
241-
// diopiDtype_t dtype;
242-
// diopiGetTensorDtype(tmp_, &dtype);
243-
// diopiSize_t size, stride, dstStride;
244-
// diopiGetTensorShape(payload_, &size);
245-
// diopiGetTensorStride(payload_, &stride);
246-
// diopiDevice_t device;
247-
// diopiGetTensorDevice(payload_, &device);
248-
// diopiTensorHandle_t tmp = nullptr;
249-
// diopiRequireTensor(ctx_, &tmp, &size, &stride, dtype, device);
250-
// diopiCopyInp(ctx_, tmp_, tmp);
251-
// diopiCastDtype(ctx_, payload_, tmp);
252-
// }
253243
}
254244

255245
public:

impl/camb/common/common.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ diopiError_t transpose(diopiContextHandle_t ctx, const DiopiTensor& inputTensor,
5555

5656
bool denseCheck(const DiopiTensor& src);
5757

58+
bool isSlice(const DiopiTensor& src);
59+
60+
bool isSparse(const DiopiTensor& src);
61+
62+
diopiError_t getDenseStride(const DiopiTensor& src, std::vector<int64_t>& dstStride);
63+
64+
diopiError_t sliceToDense(diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dst);
65+
66+
diopiError_t toDense(diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dst);
67+
5868
} // namespace camb
5969
} // namespace impl
6070

impl/camb/common/contiguous.cpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
#include <iostream>
88
#include <vector>
99

10-
#include "../common/debug.hpp"
1110
#include "common.hpp"
12-
1311
namespace impl {
1412
namespace camb {
1513

@@ -220,9 +218,20 @@ diopiError_t permuteTensor(DiopiTensor& t, const std::vector<int32_t>& order) {
220218

221219
// inplace contiguous, support NCHW <-> NHWC, NCDHW <-> NDHWC NCL <-> NLC
222220
diopiError_t contiguous(diopiContextHandle_t ctx, DiopiTensor& src, diopiMemoryFormat_t memoryFormat) {
221+
if (!denseCheck(src)) {
222+
DiopiTensor denseOut;
223+
toDense(ctx, src, denseOut);
224+
src = denseOut;
225+
if (memoryFormat == diopiMemoryFormat_t::Preserve) {
226+
// no need for further permute, if memoryFormat is Preserve.
227+
return diopiSuccess;
228+
}
229+
}
230+
223231
if (src.isContiguous(memoryFormat)) {
224232
return diopiSuccess;
225233
}
234+
226235
int64_t dim = src.dim();
227236
DIOPI_CHECK(dim <= 8, "only support less than 8d tensor currently");
228237
diopiMemoryFormat_t srcMemoryFormat;

0 commit comments

Comments
 (0)