Skip to content

Commit ffc8cd4

Browse files
committed
use "SHARPY::is_contiguous" to check strided and simpliy wait code
1 parent c4aa251 commit ffc8cd4

File tree

2 files changed

+15
-20
lines changed

2 files changed

+15
-20
lines changed

src/idtr.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -504,43 +504,37 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
504504
}
505505
}
506506

507-
int64_t oStride = std::accumulate(oDataStridesPtr, oDataStridesPtr + oNDims,
508-
1, std::multiplies<int64_t>());
509-
void *rBuff = oDataPtr;
510-
if (oStride != 1) {
511-
rBuff = new char[sizeof_dtype(sharpytype) * myOSz];
512-
}
507+
bool isStrided =
508+
!SHARPY::is_contiguous(oDataShapePtr, oDataStridesPtr, oNDims);
509+
void *rBuff =
510+
isStrided ? new char[sizeof_dtype(sharpytype) * myOSz] : oDataPtr;
513511

514512
SHARPY::Buffer sendbuff(totSSz * sizeof_dtype(sharpytype), 2);
515513
bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
516514
lsOffs.data(), lsEnds.data(), sendbuff.data());
517515
auto hdl = tc->alltoall(sendbuff.data(), sszs.data(), soffs.data(),
518516
sharpytype, rBuff, rszs.data(), roffs.data());
519517

520-
if (no_async) {
521-
tc->wait(hdl);
522-
if (oStride != 1) {
523-
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524-
oDataPtr);
525-
delete[](char *) rBuff;
526-
}
527-
return nullptr;
528-
}
529-
530-
auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
518+
auto wait = [tc, hdl, isStrided, rBuff, sharpytype, oDataShapePtr,
531519
oDataStridesPtr, oNDims, oDataPtr,
532520
sendbuff = std::move(sendbuff), sszs = std::move(sszs),
533521
soffs = std::move(soffs), rszs = std::move(rszs),
534522
roffs = std::move(roffs)]() {
535523
tc->wait(hdl);
536-
if (oStride != 1) {
524+
if (isStrided) {
537525
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538526
oDataPtr);
539527
delete[](char *) rBuff;
540528
}
541529
};
542530
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
543531
roffs.empty());
532+
533+
if (no_async) {
534+
wait();
535+
return nullptr;
536+
}
537+
544538
return mkWaitHandle(std::move(wait));
545539
}
546540

test/test_random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from utils import device
34

45
import sharpy as sp
56

@@ -16,7 +17,7 @@ def seed(request):
1617

1718
def test_random_rand(shape, seed):
1819
sp.random.seed(seed)
19-
sp_data = sp.random.rand(*shape)
20+
sp_data = sp.random.rand(*shape, device=device)
2021

2122
np.random.seed(seed)
2223
np_data = np.random.rand(*shape)
@@ -30,7 +31,7 @@ def test_random_rand(shape, seed):
3031
@pytest.mark.parametrize("low,high", [(0, 1), (4, 10), (-100, 100)])
3132
def test_random_uniform(low, high, shape, seed):
3233
sp.random.seed(seed)
33-
sp_data = sp.random.uniform(low, high, shape)
34+
sp_data = sp.random.uniform(low, high, shape, device=device)
3435

3536
np.random.seed(seed)
3637
np_data = np.random.uniform(low, high, shape)

0 commit comments

Comments
 (0)