13
13
#include < ostream>
14
14
#include < vector>
15
15
16
+ bool denseCheckAdaptor (diopiSize_t shape, diopiSize_t stride);
17
+
16
18
std::vector<int64_t > calcStrides (diopiSize_t size, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous);
17
19
18
20
bool isLikeChannelsLast (diopiConstTensorHandle_t tensor, bool checkContiguous, diopiMemoryFormat_t format = diopiMemoryFormat_t::ChannelsLast);
@@ -52,7 +54,7 @@ struct RemoveConst<diopiConstTensorHandle_t> {
52
54
53
55
class NoCast {
54
56
public:
55
- static bool getDstDtype (diopiDtype_t srcDtype, diopiDtype_t & targetDtype) {
57
+ static bool getDstDtype (diopiDtype_t srcDtype, diopiDtype_t& targetDtype) {
56
58
bool convert = false ;
57
59
switch (srcDtype) {
58
60
default :
@@ -63,7 +65,7 @@ class NoCast {
63
65
};
64
66
65
67
template <class T , class strategy = NoCast>
66
- ConvertType castImpl (diopiContextHandle_t ctx, T src, T * dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
68
+ ConvertType castImpl (diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
67
69
ConvertType convertType;
68
70
if (!src) {
69
71
*dst = src;
@@ -77,13 +79,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
77
79
strategy::getDstDtype (srcDtype, dstDtype);
78
80
std::vector<diopiMemoryFormat_t> targetMemoryFormats = obtainTargetMemoryFormats (srcSize.len , supportMemoryFormats);
79
81
diopiTensorHandle_t memoryFormatedTensor = nullptr ;
80
-
81
82
// convertDtype
82
83
83
84
diopiDevice_t device;
84
85
diopiGetTensorDevice (src, &device);
85
86
diopiTensorHandle_t tmp0 = nullptr ;
86
87
bool needConvertDtype = srcDtype != dstDtype;
88
+
87
89
if (needConvertDtype) {
88
90
diopiRequireTensor (ctx, &tmp0, &srcSize, &srcStride, dstDtype, device);
89
91
diopiCastDtype (ctx, tmp0, src);
@@ -108,6 +110,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
108
110
}
109
111
diopiSize_t dstStride = srcStride;
110
112
diopiSize_t dstSize = srcSize;
113
+ if (!targetMemoryFormats.empty ()) {
114
+ if (!denseCheckAdaptor (srcSize, srcStride) && supportMemoryFormats[0 ] == diopiMemoryFormat_t::Preserve) {
115
+ targetMemoryFormats.push_back (diopiMemoryFormat_t::Preserve);
116
+ needConvertMemoryFormat = true ;
117
+ }
118
+ }
119
+
111
120
if (needConvertMemoryFormat) {
112
121
diopiContiguous (ctx, &memoryFormatedTensor, tmp0, targetMemoryFormats[0 ]);
113
122
convertType.setMemoryFormatConverted ();
@@ -122,7 +131,7 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
122
131
}
123
132
124
133
template <class T , class strategy >
125
- ConvertType requireTensorIfMemoryFormatConvert (diopiContextHandle_t ctx, T src, T * dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
134
+ ConvertType requireTensorIfMemoryFormatConvert (diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
126
135
ConvertType convertType;
127
136
if (!src) {
128
137
*dst = src;
@@ -139,6 +148,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
139
148
if (targetMemoryFormats.empty ()) {
140
149
needConvertMemoryFormat = false ;
141
150
}
151
+
142
152
for (auto memoryFormat : targetMemoryFormats) {
143
153
if (isContiguous (srcSize, srcStride, memoryFormat)) {
144
154
needConvertMemoryFormat = false ;
@@ -174,7 +184,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
174
184
}
175
185
176
186
template <typename Adaptor, typename ... Args>
177
- void dispatchDiopi (diopiContextHandle_t ctx, Args &&...args) {
187
+ void dispatchDiopi (diopiContextHandle_t ctx, Args&&... args) {
178
188
auto adaptor = Adaptor ();
179
189
adaptor (ctx, std::forward<Args>(args)...);
180
190
}
@@ -195,10 +205,10 @@ template <class strategy = NoCast>
195
205
class DiopiTensorWrapper {
196
206
public:
197
207
// forbid copy/move constructor/assignment
198
- DiopiTensorWrapper (const DiopiTensorWrapper &) = delete ;
199
- DiopiTensorWrapper & operator =(const DiopiTensorWrapper &) = delete ;
200
- DiopiTensorWrapper (DiopiTensorWrapper &&) = delete ;
201
- DiopiTensorWrapper & operator =(DiopiTensorWrapper &&) = delete ;
208
+ DiopiTensorWrapper (const DiopiTensorWrapper&) = delete ;
209
+ DiopiTensorWrapper& operator =(const DiopiTensorWrapper&) = delete ;
210
+ DiopiTensorWrapper (DiopiTensorWrapper&&) = delete ;
211
+ DiopiTensorWrapper& operator =(DiopiTensorWrapper&&) = delete ;
202
212
203
213
private:
204
214
diopiContextHandle_t ctx_;
@@ -230,26 +240,6 @@ class DiopiTensorWrapper {
230
240
if (convertType_.isDtypeConverted ()) {
231
241
diopiCastDtype (ctx_, payload_, memoryFormatedTensor);
232
242
}
233
-
234
- // if (convertType_.isDtypeConverted() &&
235
- // !convertType_.isMemoryFormatConverted()) {
236
- // diopiCastDtype(ctx_, payload_, tmp_);
237
- // } else if (!convertType_.isDtypeConverted() &&
238
- // convertType_.isMemoryFormatConverted()) {
239
- // diopiCopyInp(ctx_, tmp_, payload_);
240
- // } else {
241
- // diopiDtype_t dtype;
242
- // diopiGetTensorDtype(tmp_, &dtype);
243
- // diopiSize_t size, stride, dstStride;
244
- // diopiGetTensorShape(payload_, &size);
245
- // diopiGetTensorStride(payload_, &stride);
246
- // diopiDevice_t device;
247
- // diopiGetTensorDevice(payload_, &device);
248
- // diopiTensorHandle_t tmp = nullptr;
249
- // diopiRequireTensor(ctx_, &tmp, &size, &stride, dtype, device);
250
- // diopiCopyInp(ctx_, tmp_, tmp);
251
- // diopiCastDtype(ctx_, payload_, tmp);
252
- // }
253
243
}
254
244
255
245
public:
0 commit comments