@@ -51,14 +51,30 @@ template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
51
51
return new WaitHandle<T>(fini);
52
52
};
53
53
54
// _idtr_wait: if `handle` is non-null, calls handle->wait() and then deletes
// the handle; a null handle makes the call a no-op. Because the handle is
// deleted here, the caller must not touch it after this function returns.
//
// This hunk changes the signature: the old one-argument form (removed lines)
// was declared right after an extern-"C" opener, while the new form takes four
// extra arguments (lHaloRank, lHaloDescr, rHaloRank, rHaloDescr). None of the
// four is read in the body shown below.
// NOTE(review): presumably the halo arguments exist only so the caller keeps
// the halo descriptors alive until the wait has completed - confirm with the
// code generator that emits the call.
// NOTE(review): `delete handle` goes through a WaitHandleBase pointer; this
// relies on WaitHandleBase having a virtual destructor, which is not visible
// in this hunk - confirm.
// NOTE(review): the linkage string is rendered with blanks inside the quotes;
// that is not the standard "C" linkage spelling and is most likely an artifact
// of how this diff was extracted - verify against the real file.
- extern " C " {
55
- void _idtr_wait (WaitHandleBase *handle ) {
54
+ void _idtr_wait (WaitHandleBase *handle, int64_t lHaloRank, void *lHaloDescr,
55
+ int64_t rHaloRank, void *rHaloDescr ) {
56
56
// Null handle: nothing to wait for and nothing to free.
if (handle) {
57
57
handle->wait ();
58
58
// Ownership ends here; the handle is released right after the wait.
delete handle;
59
59
}
60
60
}
61
61
62
+ extern " C" {
63
// TYPED_WAIT(_sfx) defines a function named _idtr_wait_<_sfx> that takes the
// same five arguments as _idtr_wait (handle plus left/right halo rank and
// descriptor) and forwards all of them unchanged to _idtr_wait. The `return`
// in front of the call only forwards a void expression.
// The definitions follow the extern-"C" opener added just above in this hunk;
// the matching closing brace is not part of the lines shown here.
// NOTE(review): as rendered, there is a blank between `TYPED_WAIT` and `(` on
// the #define line. In real source that would declare an object-like macro
// instead of a function-like one and the uses below would not expand as
// intended; this is most likely an artifact of the diff extraction - verify
// against the real file.
// NOTE(review): the suffixes presumably name element types (f64, i32, ...);
// the wrappers themselves do not depend on the suffix beyond the symbol name.
+ #define TYPED_WAIT (_sfx ) \
64
+ void _idtr_wait_##_sfx(WaitHandleBase *handle, int64_t lHaloRank, \
65
+ void *lHaloDescr, int64_t rHaloRank, \
66
+ void *rHaloDescr) { \
67
+ return _idtr_wait (handle, lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
68
+ }
69
+
// One wrapper per suffix: f64, f32, i64, i32, i16, i8 and i1. The `;` after
// each use leaves an empty declaration behind the generated function body.
70
+ TYPED_WAIT (f64 );
71
+ TYPED_WAIT (f32 );
72
+ TYPED_WAIT (i64 );
73
+ TYPED_WAIT (i32 );
74
+ TYPED_WAIT (i16 );
75
+ TYPED_WAIT (i8 );
76
+ TYPED_WAIT (i1);
77
+
62
78
#define NO_TRANSCEIVER
63
79
#ifdef NO_TRANSCEIVER
64
80
static void initMPIRuntime () {
@@ -486,6 +502,8 @@ struct UHCache {
486
502
std::vector<int > _lSendSize, _rSendSize, _lSendOff, _rSendOff;
487
503
// receive maps
488
504
std::vector<int > _lRecvSize, _rRecvSize, _lRecvOff, _rRecvOff;
505
+ // buffers
506
+ Buffer _recvBuff, _sendLBuff, _sendRBuff;
489
507
bool _bufferizeSend, _bufferizeLRecv, _bufferizeRRecv;
490
508
// start and sizes for chunks from remotes if copies are needed
491
509
int64_t _lTotalRecvSize, _rTotalRecvSize, _lTotalSendSize, _rTotalSendSize;
@@ -502,6 +520,7 @@ struct UHCache {
502
520
std::vector<int > &&rSendSize, std::vector<int > &&lSendOff,
503
521
std::vector<int > &&rSendOff, std::vector<int > &&lRecvSize,
504
522
std::vector<int > &&rRecvSize, std::vector<int > &&lRecvOff,
523
+ Buffer &&recvBuff, Buffer &&sendLBuff, Buffer &&sendRBuff,
505
524
std::vector<int > &&rRecvOff, bool bufferizeSend, bool bufferizeLRecv,
506
525
bool bufferizeRRecv, int64_t lTotalRecvSize, int64_t rTotalRecvSize,
507
526
int64_t lTotalSendSize, int64_t rTotalSendSize)
@@ -515,10 +534,11 @@ struct UHCache {
515
534
_lSendOff(std::move(lSendOff)), _rSendOff(std::move(rSendOff)),
516
535
_lRecvSize(std::move(lRecvSize)), _rRecvSize(std::move(rRecvSize)),
517
536
_lRecvOff(std::move(lRecvOff)), _rRecvOff(std::move(rRecvOff)),
518
- _bufferizeSend(bufferizeSend), _bufferizeLRecv(bufferizeLRecv),
519
- _bufferizeRRecv(bufferizeRRecv), _lTotalRecvSize(lTotalRecvSize),
520
- _rTotalRecvSize(rTotalRecvSize), _lTotalSendSize(lTotalSendSize),
521
- _rTotalSendSize(rTotalSendSize) {}
537
+ _recvBuff(std::move(recvBuff)), _sendLBuff(std::move(sendLBuff)),
538
+ _sendRBuff(std::move(sendRBuff)), _bufferizeSend(bufferizeSend),
539
+ _bufferizeLRecv(bufferizeLRecv), _bufferizeRRecv(bufferizeRRecv),
540
+ _lTotalRecvSize(lTotalRecvSize), _rTotalRecvSize(rTotalRecvSize),
541
+ _lTotalSendSize(lTotalSendSize), _rTotalSendSize(rTotalSendSize) {}
522
542
UHCache &operator =(const UHCache &) = delete ;
523
543
UHCache &operator =(UHCache &&) = default ;
524
544
};
@@ -712,35 +732,40 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
712
732
}
713
733
cache = &(cIt->second );
714
734
715
- Buffer recvBuff (0 ), sendBuff (0 );
716
735
if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ) {
717
- recvBuff.resize (std::max (cache->_lTotalRecvSize , cache->_rTotalRecvSize ) *
718
- sizeof_dtype (ddpttype));
736
+ cache->_recvBuff .resize (
737
+ std::max (cache->_lTotalRecvSize , cache->_rTotalRecvSize ) *
738
+ sizeof_dtype (ddpttype));
719
739
}
720
740
if (cache->_bufferizeSend ) {
721
- sendBuff .resize (std::max ( cache->_lTotalSendSize , cache-> _rTotalSendSize ) *
722
- sizeof_dtype (ddpttype));
741
+ cache-> _sendLBuff .resize (cache->_lTotalSendSize * sizeof_dtype (ddpttype));
742
+ cache-> _sendRBuff . resize (cache-> _rTotalSendSize * sizeof_dtype (ddpttype));
723
743
}
724
744
725
- void *lRecvData = cache->_bufferizeLRecv ? recvBuff.data () : leftHaloData;
726
- void *rRecvData = cache->_bufferizeRRecv ? recvBuff.data () : rightHaloData;
727
- void *sendData = cache->_bufferizeSend ? sendBuff.data () : ownedData;
745
+ void *lRecvData =
746
+ cache->_bufferizeLRecv ? cache->_recvBuff .data () : leftHaloData;
747
+ void *rRecvData =
748
+ cache->_bufferizeRRecv ? cache->_recvBuff .data () : rightHaloData;
749
+ void *lSendData =
750
+ cache->_bufferizeSend ? cache->_sendLBuff .data () : ownedData;
751
+ void *rSendData =
752
+ cache->_bufferizeSend ? cache->_sendRBuff .data () : ownedData;
728
753
729
754
// communicate left/right halos
730
755
if (cache->_bufferizeSend ) {
731
756
bufferize (ownedData, ddpttype, ownedShape, ownedStride,
732
757
cache->_lBufferStart .data (), cache->_lBufferSize .data (), ndims,
733
- nworkers, sendBuff .data ());
758
+ nworkers, cache-> _sendLBuff .data ());
734
759
}
735
- auto lwh = tc->alltoall (sendData , cache->_lSendSize .data (),
760
+ auto lwh = tc->alltoall (lSendData , cache->_lSendSize .data (),
736
761
cache->_lSendOff .data (), ddpttype, lRecvData,
737
762
cache->_lRecvSize .data (), cache->_lRecvOff .data ());
738
763
if (cache->_bufferizeSend ) {
739
764
bufferize (ownedData, ddpttype, ownedShape, ownedStride,
740
765
cache->_rBufferStart .data (), cache->_rBufferSize .data (), ndims,
741
- nworkers, sendBuff .data ());
766
+ nworkers, cache-> _sendRBuff .data ());
742
767
}
743
- auto rwh = tc->alltoall (sendData , cache->_rSendSize .data (),
768
+ auto rwh = tc->alltoall (rSendData , cache->_rSendSize .data (),
744
769
cache->_rSendOff .data (), ddpttype, rRecvData,
745
770
cache->_rRecvSize .data (), cache->_rRecvOff .data ());
746
771
@@ -760,7 +785,6 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
760
785
}
761
786
};
762
787
763
- // FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
764
788
if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ||
765
789
getenv (" DDPT_NO_ASYNC" )) {
766
790
wait ();
0 commit comments