Skip to content

Commit 3ee8ca5

Browse files
committed
implement gloo abort
1 parent 81925d1 commit 3ee8ca5

File tree

5 files changed

+90
-11
lines changed

5 files changed

+90
-11
lines changed

gloo/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Source files from this directory, collected into GLOO_COMMON_SRCS.
# error.cc provides the definitions for the abort helpers declared in error.h.
set(GLOO_COMMON_SRCS
22
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
33
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
4+
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
45
)
56

67
set(GLOO_COMMON_HDRS

gloo/common/error.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <atomic>
10+
#include <list>
11+
12+
#include "gloo/common/error.h"
13+
14+
namespace gloo {
15+
16+
17+
std::list<std::condition_variable *> _cvs;
18+
std::mutex _cvs_mutex;
19+
20+
std::atomic_bool _is_aborted_flag(false);
21+
22+
bool _is_aborted() {
23+
return _is_aborted_flag.load();
24+
}
25+
26+
void abort() {
27+
_is_aborted_flag.store(true);
28+
std::lock_guard<std::mutex> guard(_cvs_mutex);
29+
for(auto& cv : _cvs) {
30+
if(cv != NULL) {
31+
cv->notify_all();
32+
}
33+
}
34+
GLOO_THROW("GLOO ABORTED");
35+
}
36+
37+
void _register_cv(std::condition_variable *cv) {
38+
std::lock_guard<std::mutex> guard(_cvs_mutex);
39+
_cvs.push_back(cv);
40+
}
41+
42+
void _deregister_cv(std::condition_variable *cv) {
43+
std::lock_guard<std::mutex> guard(_cvs_mutex);
44+
_cvs.remove(cv);
45+
}
46+
} // namespace gloo

gloo/common/error.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <chrono>
1212
#include <exception>
13+
#include <condition_variable>
1314

1415
#include "gloo/common/string.h"
1516

@@ -20,6 +21,11 @@ namespace gloo {
2021

2122
const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();
2223

24+
// Returns true once abort() has been called.
bool _is_aborted();
25+
// Sets the global abort flag, calls notify_all() on every registered
// condition variable, and then invokes GLOO_THROW("GLOO ABORTED").
void abort();
26+
// Adds `cv` to the list of condition variables that abort() notifies.
// The pointer is stored as-is; the caller keeps ownership.
void _register_cv(std::condition_variable *cv);
27+
// Removes `cv` from the list of condition variables that abort() notifies.
void _deregister_cv(std::condition_variable *cv);
28+
2329
// A base class for all gloo runtime errors
2430
struct Exception : public std::runtime_error {
2531
Exception() = delete;

gloo/transport/tcp/unbound_buffer.cc

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
2828
recvRank_(-1),
2929
sendCompletions_(0),
3030
sendRank_(-1),
31-
shareableNonOwningPtr_(this) {}
31+
shareableNonOwningPtr_(this) {
32+
gloo::_register_cv(&recvCv_);
33+
gloo::_register_cv(&sendCv_);
34+
}
3235

33-
UnboundBuffer::~UnboundBuffer() {}
36+
// Removes this buffer's two condition variables from the global abort
// registry, so gloo::abort() no longer holds pointers to them after the
// buffer is destroyed.
UnboundBuffer::~UnboundBuffer() {
37+
gloo::_deregister_cv(&recvCv_);
38+
gloo::_deregister_cv(&sendCv_);
39+
}
3440

3541
void UnboundBuffer::handleRecvCompletion(int rank) {
3642
std::lock_guard<std::mutex> lock(m_);
@@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
6066
if (recvCompletions_ == 0) {
6167
auto done = recvCv_.wait_for(lock, timeout, [&] {
6268
throwIfException();
69+
if(gloo::_is_aborted()) {
70+
abortWaitRecv_ = true;
71+
}
6372
return abortWaitRecv_ || recvCompletions_ > 0;
6473
});
6574
if (!done) {
@@ -111,9 +120,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
111120

112121
if (sendCompletions_ == 0) {
113122
auto done = sendCv_.wait_for(lock, timeout, [&] {
114-
throwIfException();
115-
return abortWaitSend_ || sendCompletions_ > 0;
116-
});
123+
throwIfException();
124+
if(gloo::_is_aborted()) {
125+
abortWaitSend_ = true;
126+
}
127+
return abortWaitSend_ || sendCompletions_ > 0;
128+
});
117129
if (!done) {
118130
// Below, we let all pairs in the transport context know about this
119131
// application side timeout. This in turn will call into all pending

gloo/transport/uv/unbound_buffer.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
2828
recvRank_(-1),
2929
sendCompletions_(0),
3030
sendRank_(-1),
31-
shareableNonOwningPtr_(this) {}
31+
shareableNonOwningPtr_(this) {
32+
gloo::_register_cv(&recvCv_);
33+
gloo::_register_cv(&sendCv_);
34+
}
3235

33-
UnboundBuffer::~UnboundBuffer() {}
36+
// Removes this buffer's two condition variables from the global abort
// registry, so gloo::abort() no longer holds pointers to them after the
// buffer is destroyed.
UnboundBuffer::~UnboundBuffer() {
37+
gloo::_deregister_cv(&recvCv_);
38+
gloo::_deregister_cv(&sendCv_);
39+
}
3440

3541
void UnboundBuffer::handleRecvCompletion(int rank) {
3642
std::lock_guard<std::mutex> lock(mutex_);
@@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
5864
}
5965

6066
if (recvCompletions_ == 0) {
61-
auto done = recvCv_.wait_for(
62-
lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
67+
auto done = recvCv_.wait_for(lock, timeout, [&] {
68+
if(gloo::_is_aborted()) {
69+
abortWaitRecv_ = true;
70+
}
71+
return abortWaitRecv_ || recvCompletions_ > 0;
72+
});
6373
if (!done) {
6474
throw ::gloo::IoException(GLOO_ERROR_MSG(
6575
"Timed out waiting ",
@@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
94104
}
95105

96106
if (sendCompletions_ == 0) {
97-
auto done = sendCv_.wait_for(
98-
lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
107+
auto done = sendCv_.wait_for(lock, timeout, [&] {
108+
if(gloo::_is_aborted()) {
109+
abortWaitSend_ = true;
110+
}
111+
return abortWaitSend_ || sendCompletions_ > 0;
112+
});
99113
if (!done) {
100114
throw ::gloo::IoException(GLOO_ERROR_MSG(
101115
"Timed out waiting ",

0 commit comments

Comments
 (0)