@@ -268,6 +268,21 @@ void unpack(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
268
268
});
269
269
}
270
270
271
+ // / copy contiguous block of data into a possibly strided array
272
+ void unpack1 (void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
273
+ const int64_t *strides, uint64_t ndim, void *out) {
274
+ if (!in || !sizes || !strides || !out) {
275
+ return ;
276
+ }
277
+ dispatch (dtype, out, [sizes, strides, ndim, in](auto *out_) {
278
+ auto in_ = static_cast <decltype (out_)>(in);
279
+ SHARPY::forall (0 , out_, sizes, strides, ndim, [&in_](auto *out) {
280
+ *out = *in_;
281
+ ++in_;
282
+ });
283
+ });
284
+ }
285
+
271
286
template <typename T>
272
287
void copy_ (uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
273
288
const int64_t *strides, const uint64_t *chunks, uint64_t nd,
@@ -489,21 +504,41 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
489
504
}
490
505
}
491
506
507
+ int64_t oStride = std::accumulate (oDataStridesPtr, oDataStridesPtr + oNDims,
508
+ 1 , std::multiplies<int64_t >());
509
+ void *rBuff = oDataPtr;
510
+ if (oStride != 1 ) {
511
+ rBuff = new char [sizeof_dtype (sharpytype) * myOSz];
512
+ }
513
+
492
514
SHARPY::Buffer sendbuff (totSSz * sizeof_dtype (sharpytype), 2 );
493
515
bufferizeN (iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
494
516
lsOffs.data (), lsEnds.data (), sendbuff.data ());
495
517
auto hdl = tc->alltoall (sendbuff.data (), sszs.data (), soffs.data (),
496
- sharpytype, oDataPtr , rszs.data (), roffs.data ());
518
+ sharpytype, rBuff , rszs.data (), roffs.data ());
497
519
498
520
if (no_async) {
499
521
tc->wait (hdl);
522
+ if (oStride != 1 ) {
523
+ unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524
+ oDataPtr);
525
+ delete[] (char *)rBuff;
526
+ }
500
527
return nullptr ;
501
528
}
502
529
503
- auto wait = [tc = tc, hdl = hdl, sendbuff = std::move (sendbuff),
504
- sszs = std::move (sszs), soffs = std::move (soffs),
505
- rszs = std::move (rszs),
506
- roffs = std::move (roffs)]() { tc->wait (hdl); };
530
+ auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
531
+ oDataStridesPtr, oNDims, oDataPtr,
532
+ sendbuff = std::move (sendbuff), sszs = std::move (sszs),
533
+ soffs = std::move (soffs), rszs = std::move (rszs),
534
+ roffs = std::move (roffs)]() {
535
+ tc->wait (hdl);
536
+ if (oStride != 1 ) {
537
+ unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538
+ oDataPtr);
539
+ delete[] (char *)rBuff;
540
+ }
541
+ };
507
542
assert (sendbuff.empty () && sszs.empty () && soffs.empty () && rszs.empty () &&
508
543
roffs.empty ());
509
544
return mkWaitHandle (std::move (wait ));
0 commit comments