Skip to content

Commit a7a49d1

Browse files
committed
add test
1 parent 0ad4df8 commit a7a49d1

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
find_package(OpenSSL 1.1 REQUIRED EXACT)
22

33
set(GLOO_TEST_SRCS
4+
"${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc"
45
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
56
"${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc"
67
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"

gloo/test/abort_test.cc

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <functional>
10+
#include <thread>
11+
#include <vector>
12+
13+
#include "gloo/barrier_all_to_all.h"
14+
#include "gloo/barrier_all_to_one.h"
15+
#include "gloo/broadcast.h"
16+
#include "gloo/test/base_test.h"
17+
18+
namespace gloo {
19+
namespace test {
20+
namespace {
21+
22+
// Function to instantiate and run algorithm.
23+
using Func = void(std::shared_ptr<::gloo::Context>);
24+
25+
// Test parameterization.
26+
using Param = std::tuple<Transport, int, std::function<Func>>;
27+
28+
// Test fixture.
29+
class BarrierTest : public BaseTest,
30+
public ::testing::WithParamInterface<Param> {};
31+
32+
TEST_P(BarrierTest, SinglePointer) {
33+
const auto transport = std::get<0>(GetParam());
34+
const auto contextSize = std::get<1>(GetParam());
35+
const auto fn = std::get<2>(GetParam());
36+
37+
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
38+
fn(context);
39+
});
40+
}
41+
42+
static std::function<Func> barrierAllToAll =
43+
[](std::shared_ptr<::gloo::Context> context) {
44+
::gloo::BarrierAllToAll algorithm(context);
45+
algorithm.run();
46+
};
47+
48+
INSTANTIATE_TEST_CASE_P(
49+
BarrierAllToAll,
50+
BarrierTest,
51+
::testing::Combine(
52+
::testing::ValuesIn(kTransportsForClassAlgorithms),
53+
::testing::Range(2, 16),
54+
::testing::Values(barrierAllToAll)));
55+
56+
static std::function<Func> barrierAllToOne =
57+
[](std::shared_ptr<::gloo::Context> context) {
58+
::gloo::BarrierAllToOne algorithm(context);
59+
algorithm.run();
60+
};
61+
62+
INSTANTIATE_TEST_CASE_P(
63+
BarrierAllToOne,
64+
BarrierTest,
65+
::testing::Combine(
66+
::testing::ValuesIn(kTransportsForClassAlgorithms),
67+
::testing::Range(2, 16),
68+
::testing::Values(barrierAllToOne)));
69+
70+
// Synchronized version of std::chrono::clock::now().
71+
// All processes participating in the specified context will
72+
// see the same value.
73+
template <typename clock>
74+
std::chrono::time_point<clock> syncNow(std::shared_ptr<Context> context) {
75+
const typename clock::time_point now = clock::now();
76+
typename clock::duration::rep count = now.time_since_epoch().count();
77+
BroadcastOptions opts(context);
78+
opts.setRoot(0);
79+
opts.setOutput(&count, 1);
80+
broadcast(opts);
81+
return typename clock::time_point(typename clock::duration(count));
82+
}
83+
84+
using NewParam = std::tuple<Transport, int>;
85+
86+
class BarrierNewTest : public BaseTest,
87+
public ::testing::WithParamInterface<NewParam> {};
88+
89+
TEST_P(BarrierNewTest, Default) {
90+
const auto transport = std::get<0>(GetParam());
91+
const auto contextSize = std::get<1>(GetParam());
92+
93+
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
94+
BarrierOptions opts(context);
95+
96+
// Run barrier to synchronize processes after starting.
97+
barrier(opts);
98+
99+
// Take turns in sleeping for a bit and checking that all processes
100+
// saw that artificial delay through the barrier.
101+
auto singleProcessDelay = std::chrono::milliseconds(1000);
102+
for (size_t i = 0; i < context->size; i++) {
103+
const auto start = syncNow<std::chrono::high_resolution_clock>(context);
104+
if (i == context->rank) {
105+
/* sleep override */
106+
std::this_thread::sleep_for(singleProcessDelay);
107+
}
108+
109+
barrier(opts);
110+
abort();
111+
112+
// Expect all processes to have taken less than the sleep, as abort was called
113+
auto stop = std::chrono::high_resolution_clock::now();
114+
auto delta = std::chrono::duration_cast<decltype(singleProcessDelay)>(
115+
stop - start);
116+
ASSERT_LE(delta.count(), singleProcessDelay.count());
117+
}
118+
});
119+
}
120+
121+
INSTANTIATE_TEST_CASE_P(
122+
BarrierNewDefault,
123+
BarrierNewTest,
124+
::testing::Combine(
125+
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
126+
::testing::Values(1, 2, 4, 7)));
127+
128+
TEST_F(BarrierNewTest, TestTimeout) {
129+
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
130+
BarrierOptions opts(context);
131+
opts.setTimeout(std::chrono::milliseconds(10));
132+
if (context->rank == 0) {
133+
try {
134+
barrier(opts);
135+
FAIL() << "Expected exception to be thrown";
136+
} catch (::gloo::IoException& e) {
137+
ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos);
138+
}
139+
}
140+
});
141+
}
142+
143+
} // namespace
144+
} // namespace test
145+
} // namespace gloo

0 commit comments

Comments
 (0)