Skip to content

Commit e026504

Browse files
committed
fixing async reshape: use proper move semantics for wait lambdas
1 parent 0636bd7 commit e026504

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/idtr.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ template <typename T> class WaitHandle : public WaitHandleBase {
5151
T _fini;
5252

5353
public:
54-
WaitHandle(T fini) : _fini(fini) {}
54+
WaitHandle(T &&fini) : _fini(std::move(fini)) {}
5555
virtual void wait() override { _fini(); }
5656
};
5757

58-
template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
59-
return new WaitHandle<T>(fini);
58+
template <typename T> WaitHandle<T> *mkWaitHandle(T &&fini) {
59+
return new WaitHandle<T>(std::move(fini));
6060
};
6161

6262
extern "C" {
@@ -489,24 +489,24 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
489489
}
490490
}
491491

492-
SHARPY::Buffer outbuff(totSSz * sizeof_dtype(sharpytype), 2);
492+
SHARPY::Buffer sendbuff(totSSz * sizeof_dtype(sharpytype), 2);
493493
bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
494-
lsOffs.data(), lsEnds.data(), outbuff.data());
495-
auto hdl = tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), sharpytype,
496-
oDataPtr, rszs.data(), roffs.data());
494+
lsOffs.data(), lsEnds.data(), sendbuff.data());
495+
auto hdl = tc->alltoall(sendbuff.data(), sszs.data(), soffs.data(),
496+
sharpytype, oDataPtr, rszs.data(), roffs.data());
497497

498-
if (true || no_async) { // FIXME remove true once IMEX is fixed
498+
if (no_async) {
499499
tc->wait(hdl);
500500
return nullptr;
501501
}
502502

503-
auto wait = [tc = tc, hdl = hdl, outbuff = std::move(outbuff),
503+
auto wait = [tc = tc, hdl = hdl, sendbuff = std::move(sendbuff),
504504
sszs = std::move(sszs), soffs = std::move(soffs),
505505
rszs = std::move(rszs),
506506
roffs = std::move(roffs)]() { tc->wait(hdl); };
507-
assert(outbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
507+
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
508508
roffs.empty());
509-
return mkWaitHandle(wait);
509+
return mkWaitHandle(std::move(wait));
510510
}
511511

512512
/// @brief reshape array
@@ -918,7 +918,7 @@ void *_idtr_update_halo(SHARPY::DTypeId sharpytype, int64_t ndims,
918918
wait();
919919
return nullptr;
920920
}
921-
return mkWaitHandle(wait);
921+
return mkWaitHandle(std::move(wait));
922922
}
923923

924924
/// @brief templated wrapper for typed function versions calling

0 commit comments

Comments
 (0)