@@ -504,43 +504,37 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
504
504
}
505
505
}
506
506
507
- int64_t oStride = std::accumulate (oDataStridesPtr, oDataStridesPtr + oNDims,
508
- 1 , std::multiplies<int64_t >());
509
- void *rBuff = oDataPtr;
510
- if (oStride != 1 ) {
511
- rBuff = new char [sizeof_dtype (sharpytype) * myOSz];
512
- }
507
+ bool isStrided =
508
+ !SHARPY::is_contiguous (oDataShapePtr, oDataStridesPtr, oNDims);
509
+ void *rBuff =
510
+ isStrided ? new char [sizeof_dtype (sharpytype) * myOSz] : oDataPtr;
513
511
514
512
SHARPY::Buffer sendbuff (totSSz * sizeof_dtype (sharpytype), 2 );
515
513
bufferizeN (iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
516
514
lsOffs.data (), lsEnds.data (), sendbuff.data ());
517
515
auto hdl = tc->alltoall (sendbuff.data (), sszs.data (), soffs.data (),
518
516
sharpytype, rBuff, rszs.data (), roffs.data ());
519
517
520
- if (no_async) {
521
- tc->wait (hdl);
522
- if (oStride != 1 ) {
523
- unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524
- oDataPtr);
525
- delete[] (char *) rBuff;
526
- }
527
- return nullptr ;
528
- }
529
-
530
- auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
518
+ auto wait = [tc, hdl, isStrided, rBuff, sharpytype, oDataShapePtr,
531
519
oDataStridesPtr, oNDims, oDataPtr,
532
520
sendbuff = std::move (sendbuff), sszs = std::move (sszs),
533
521
soffs = std::move (soffs), rszs = std::move (rszs),
534
522
roffs = std::move (roffs)]() {
535
523
tc->wait (hdl);
536
- if (oStride != 1 ) {
524
+ if (isStrided ) {
537
525
unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538
526
oDataPtr);
539
527
delete[] (char *) rBuff;
540
528
}
541
529
};
542
530
assert (sendbuff.empty () && sszs.empty () && soffs.empty () && rszs.empty () &&
543
531
roffs.empty ());
532
+
533
+ if (no_async) {
534
+ wait ();
535
+ return nullptr ;
536
+ }
537
+
544
538
return mkWaitHandle (std::move (wait));
545
539
}
546
540
0 commit comments