Skip to content

Commit 2d13221

Browse files
authored
Halo wait routine takes halo memrefs as arguments (#59)
1 parent df112f4 commit 2d13221

File tree

4 files changed

+79
-20
lines changed

4 files changed

+79
-20
lines changed

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
47768ada136b299effc408de420710fcacc12fda
1+
06523703fa96f9731b9c3c1c7f6fee0914ef9f26

src/idtr.cpp

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,30 @@ template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
5151
return new WaitHandle<T>(fini);
5252
};
5353

54-
extern "C" {
55-
void _idtr_wait(WaitHandleBase *handle) {
54+
void _idtr_wait(WaitHandleBase *handle, int64_t lHaloRank, void *lHaloDescr,
55+
int64_t rHaloRank, void *rHaloDescr) {
5656
if (handle) {
5757
handle->wait();
5858
delete handle;
5959
}
6060
}
6161

62+
extern "C" {
63+
#define TYPED_WAIT(_sfx) \
64+
void _idtr_wait_##_sfx(WaitHandleBase *handle, int64_t lHaloRank, \
65+
void *lHaloDescr, int64_t rHaloRank, \
66+
void *rHaloDescr) { \
67+
return _idtr_wait(handle, lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
68+
}
69+
70+
TYPED_WAIT(f64);
71+
TYPED_WAIT(f32);
72+
TYPED_WAIT(i64);
73+
TYPED_WAIT(i32);
74+
TYPED_WAIT(i16);
75+
TYPED_WAIT(i8);
76+
TYPED_WAIT(i1);
77+
6278
#define NO_TRANSCEIVER
6379
#ifdef NO_TRANSCEIVER
6480
static void initMPIRuntime() {
@@ -486,6 +502,8 @@ struct UHCache {
486502
std::vector<int> _lSendSize, _rSendSize, _lSendOff, _rSendOff;
487503
// receive maps
488504
std::vector<int> _lRecvSize, _rRecvSize, _lRecvOff, _rRecvOff;
505+
// buffers
506+
Buffer _recvBuff, _sendLBuff, _sendRBuff;
489507
bool _bufferizeSend, _bufferizeLRecv, _bufferizeRRecv;
490508
// start and sizes for chunks from remotes if copies are needed
491509
int64_t _lTotalRecvSize, _rTotalRecvSize, _lTotalSendSize, _rTotalSendSize;
@@ -502,6 +520,7 @@ struct UHCache {
502520
std::vector<int> &&rSendSize, std::vector<int> &&lSendOff,
503521
std::vector<int> &&rSendOff, std::vector<int> &&lRecvSize,
504522
std::vector<int> &&rRecvSize, std::vector<int> &&lRecvOff,
523+
Buffer &&recvBuff, Buffer &&sendLBuff, Buffer &&sendRBuff,
505524
std::vector<int> &&rRecvOff, bool bufferizeSend, bool bufferizeLRecv,
506525
bool bufferizeRRecv, int64_t lTotalRecvSize, int64_t rTotalRecvSize,
507526
int64_t lTotalSendSize, int64_t rTotalSendSize)
@@ -515,10 +534,11 @@ struct UHCache {
515534
_lSendOff(std::move(lSendOff)), _rSendOff(std::move(rSendOff)),
516535
_lRecvSize(std::move(lRecvSize)), _rRecvSize(std::move(rRecvSize)),
517536
_lRecvOff(std::move(lRecvOff)), _rRecvOff(std::move(rRecvOff)),
518-
_bufferizeSend(bufferizeSend), _bufferizeLRecv(bufferizeLRecv),
519-
_bufferizeRRecv(bufferizeRRecv), _lTotalRecvSize(lTotalRecvSize),
520-
_rTotalRecvSize(rTotalRecvSize), _lTotalSendSize(lTotalSendSize),
521-
_rTotalSendSize(rTotalSendSize) {}
537+
_recvBuff(std::move(recvBuff)), _sendLBuff(std::move(sendLBuff)),
538+
_sendRBuff(std::move(sendRBuff)), _bufferizeSend(bufferizeSend),
539+
_bufferizeLRecv(bufferizeLRecv), _bufferizeRRecv(bufferizeRRecv),
540+
_lTotalRecvSize(lTotalRecvSize), _rTotalRecvSize(rTotalRecvSize),
541+
_lTotalSendSize(lTotalSendSize), _rTotalSendSize(rTotalSendSize) {}
522542
UHCache &operator=(const UHCache &) = delete;
523543
UHCache &operator=(UHCache &&) = default;
524544
};
@@ -712,35 +732,40 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
712732
}
713733
cache = &(cIt->second);
714734

715-
Buffer recvBuff(0), sendBuff(0);
716735
if (cache->_bufferizeLRecv || cache->_bufferizeRRecv) {
717-
recvBuff.resize(std::max(cache->_lTotalRecvSize, cache->_rTotalRecvSize) *
718-
sizeof_dtype(ddpttype));
736+
cache->_recvBuff.resize(
737+
std::max(cache->_lTotalRecvSize, cache->_rTotalRecvSize) *
738+
sizeof_dtype(ddpttype));
719739
}
720740
if (cache->_bufferizeSend) {
721-
sendBuff.resize(std::max(cache->_lTotalSendSize, cache->_rTotalSendSize) *
722-
sizeof_dtype(ddpttype));
741+
cache->_sendLBuff.resize(cache->_lTotalSendSize * sizeof_dtype(ddpttype));
742+
cache->_sendRBuff.resize(cache->_rTotalSendSize * sizeof_dtype(ddpttype));
723743
}
724744

725-
void *lRecvData = cache->_bufferizeLRecv ? recvBuff.data() : leftHaloData;
726-
void *rRecvData = cache->_bufferizeRRecv ? recvBuff.data() : rightHaloData;
727-
void *sendData = cache->_bufferizeSend ? sendBuff.data() : ownedData;
745+
void *lRecvData =
746+
cache->_bufferizeLRecv ? cache->_recvBuff.data() : leftHaloData;
747+
void *rRecvData =
748+
cache->_bufferizeRRecv ? cache->_recvBuff.data() : rightHaloData;
749+
void *lSendData =
750+
cache->_bufferizeSend ? cache->_sendLBuff.data() : ownedData;
751+
void *rSendData =
752+
cache->_bufferizeSend ? cache->_sendRBuff.data() : ownedData;
728753

729754
// communicate left/right halos
730755
if (cache->_bufferizeSend) {
731756
bufferize(ownedData, ddpttype, ownedShape, ownedStride,
732757
cache->_lBufferStart.data(), cache->_lBufferSize.data(), ndims,
733-
nworkers, sendBuff.data());
758+
nworkers, cache->_sendLBuff.data());
734759
}
735-
auto lwh = tc->alltoall(sendData, cache->_lSendSize.data(),
760+
auto lwh = tc->alltoall(lSendData, cache->_lSendSize.data(),
736761
cache->_lSendOff.data(), ddpttype, lRecvData,
737762
cache->_lRecvSize.data(), cache->_lRecvOff.data());
738763
if (cache->_bufferizeSend) {
739764
bufferize(ownedData, ddpttype, ownedShape, ownedStride,
740765
cache->_rBufferStart.data(), cache->_rBufferSize.data(), ndims,
741-
nworkers, sendBuff.data());
766+
nworkers, cache->_sendRBuff.data());
742767
}
743-
auto rwh = tc->alltoall(sendData, cache->_rSendSize.data(),
768+
auto rwh = tc->alltoall(rSendData, cache->_rSendSize.data(),
744769
cache->_rSendOff.data(), ddpttype, rRecvData,
745770
cache->_rRecvSize.data(), cache->_rRecvOff.data());
746771

@@ -760,7 +785,6 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
760785
}
761786
};
762787

763-
// FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
764788
if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ||
765789
getenv("DDPT_NO_ASYNC")) {
766790
wait();

test/test_ewb.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ def test_add3(self):
4747
v = 16 * 16 * 3
4848
assert float(r1) == v
4949

50+
def test_add4(self):
51+
for dtyp in mpi_idtypes:
52+
n = 16
53+
a = dt.fromfunction(lambda i, j: i, (n, n), dtype=dtyp)
54+
b = dt.ones((n, n), dtyp)
55+
c = a + b
56+
a[:, :] = a[:, :] + c[:, :]
57+
r1 = dt.sum(a, [0, 1])
58+
v = n * n * n
59+
assert float(r1) == v
60+
5061
def test_add_mul(self):
5162
def doit(aapi):
5263
a = aapi.zeros((16, 16), dtype=aapi.int64)

test/test_setget.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,30 @@ def doit(aapi):
6060

6161
assert runAndCompare(doit)
6262

63+
def test_setitem6(self):
64+
def doit(aapi):
65+
n = 16
66+
a = aapi.fromfunction(lambda i, j: i, (n, n), dtype=aapi.float64)
67+
b = aapi.zeros((n + 1, n + 1), aapi.float64)
68+
69+
b[1:n, 1:n] = a[1:n, 1:n]
70+
return b
71+
72+
assert runAndCompare(doit)
73+
74+
def test_setitem7(self):
75+
# Note: assert halo does not segfault
76+
def doit(aapi):
77+
n = 1024
78+
a = aapi.fromfunction(lambda i, j: i, (n, n), dtype=aapi.float64)
79+
b = aapi.zeros((n, n), aapi.float64)
80+
81+
b[1:n, 1:n] = a[1:n, 1:n]
82+
b[0, 1:n] = a[0, 1:n]
83+
return b
84+
85+
assert runAndCompare(doit)
86+
6387
def test_colon(self):
6488
a = dt.ones((16, 16), dt.float64)
6589
b = dt.zeros((16, 16), dt.float64)

0 commit comments

Comments
 (0)