Skip to content

Commit f081176

Browse files
committed
pass mpi rambo
1 parent aa27198 commit f081176

File tree

4 files changed

+67
-6
lines changed

4 files changed

+67
-6
lines changed

examples/rambo.py

+1
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
33
Examples:
44
python rambo.py -nevts 10 -nout 10 -b sharpy -i 10000
5+
mpiexec -n 3 python rambo.py -nevts 64 -nout 64 -b sharpy -i 100
56
67
"""
78
import argparse

src/Service.cpp

+4-1
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,10 @@ struct DeferredService : public DeferredT<Service::service_promise_type,
5151
// drop from dep manager
5252
dm.drop(_a);
5353
// and from registry
54-
Registry::del(_a);
54+
dm.addReady(_a, [this](id_type guid) {
55+
assert(this->_a == guid);
56+
Registry::del(guid);
57+
});
5558
break;
5659
}
5760
case RUN:

src/idtr.cpp

+40-5
Original file line number | Diff line number | Diff line change
@@ -268,6 +268,21 @@ void unpack(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
268268
});
269269
}
270270

271+
/// copy contiguous block of data into a possibly strided array
272+
void unpack1(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
273+
const int64_t *strides, uint64_t ndim, void *out) {
274+
if (!in || !sizes || !strides || !out) {
275+
return;
276+
}
277+
dispatch(dtype, out, [sizes, strides, ndim, in](auto *out_) {
278+
auto in_ = static_cast<decltype(out_)>(in);
279+
SHARPY::forall(0, out_, sizes, strides, ndim, [&in_](auto *out) {
280+
*out = *in_;
281+
++in_;
282+
});
283+
});
284+
}
285+
271286
template <typename T>
272287
void copy_(uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
273288
const int64_t *strides, const uint64_t *chunks, uint64_t nd,
@@ -489,21 +504,41 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
489504
}
490505
}
491506

507+
int64_t oStride = std::accumulate(oDataStridesPtr, oDataStridesPtr + oNDims,
508+
1, std::multiplies<int64_t>());
509+
void *rBuff = oDataPtr;
510+
if (oStride != 1) {
511+
rBuff = new char[sizeof_dtype(sharpytype) * myOSz];
512+
}
513+
492514
SHARPY::Buffer sendbuff(totSSz * sizeof_dtype(sharpytype), 2);
493515
bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
494516
lsOffs.data(), lsEnds.data(), sendbuff.data());
495517
auto hdl = tc->alltoall(sendbuff.data(), sszs.data(), soffs.data(),
496-
sharpytype, oDataPtr, rszs.data(), roffs.data());
518+
sharpytype, rBuff, rszs.data(), roffs.data());
497519

498520
if (no_async) {
499521
tc->wait(hdl);
522+
if (oStride != 1) {
523+
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524+
oDataPtr);
525+
delete[] (char *)rBuff;
526+
}
500527
return nullptr;
501528
}
502529

503-
auto wait = [tc = tc, hdl = hdl, sendbuff = std::move(sendbuff),
504-
sszs = std::move(sszs), soffs = std::move(soffs),
505-
rszs = std::move(rszs),
506-
roffs = std::move(roffs)]() { tc->wait(hdl); };
530+
auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
531+
oDataStridesPtr, oNDims, oDataPtr,
532+
sendbuff = std::move(sendbuff), sszs = std::move(sszs),
533+
soffs = std::move(soffs), rszs = std::move(rszs),
534+
roffs = std::move(roffs)]() {
535+
tc->wait(hdl);
536+
if (oStride != 1) {
537+
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538+
oDataPtr);
539+
delete[] (char *)rBuff;
540+
}
541+
};
507542
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
508543
roffs.empty());
509544
return mkWaitHandle(std::move(wait));

test/test_setget.py

+22
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from utils import device, runAndCompare
44

55
import sharpy as sp
6+
import numpy as np
67

78

89
class TestSetGet:
@@ -116,6 +117,27 @@ def doit(aapi, **kwargs):
116117

117118
assert runAndCompare(doit)
118119

120+
def test_setitem9(self):
    """Write two reshaped random vectors into the columns of a 2-d array.

    Each column assignment targets a strided view of the destination.
    The result is compared against the same assignments done with numpy.
    """
    n_rows = 3
    first = sp.random.rand(n_rows)
    second = sp.random.rand(n_rows)

    target = sp.empty((n_rows, 2))
    first_col = sp.reshape(first, (n_rows, 1))
    second_col = sp.reshape(second, (n_rows, 1))

    # each column of `target` is a strided array
    target[:, 0] = first_col
    target[:, 1] = second_col

    # build the reference result with plain numpy
    np_first = sp.to_numpy(first)
    np_second = sp.to_numpy(second)
    expected = np.empty((n_rows, 2))
    expected[:, 0] = np_first
    expected[:, 1] = np_second

    assert np.allclose(sp.to_numpy(target), expected)
140+
119141
def test_colon(self):
120142
a = sp.ones((16, 16), sp.float32, device=device)
121143
b = sp.zeros((16, 16), sp.float32, device=device)

0 commit comments

Comments
 (0)