2
2
3
3
#include < boost/mpi/communicator.hpp>
4
4
#include < boost/mpi/environment.hpp>
5
+ #include < memory>
5
6
6
7
class UnreadMessagesDetector : public ::testing::EmptyTestEventListener {
7
8
public:
8
- UnreadMessagesDetector (boost::mpi::communicator world ) : world_ (std::move(world )) {}
9
+ UnreadMessagesDetector (boost::mpi::communicator com ) : com_ (std::move(com )) {}
9
10
10
11
void OnTestEnd (const ::testing::TestInfo& test_info) override {
11
- world_ .barrier ();
12
- if (const auto msg = world_ .iprobe (boost::mpi::any_source, boost::mpi::any_tag)) {
12
+ com_ .barrier ();
13
+ if (const auto msg = com_ .iprobe (boost::mpi::any_source, boost::mpi::any_tag)) {
13
14
fprintf (
14
15
stderr,
15
16
" [ PROCESS %d ] [ FAILED ] %s.%s: MPI message queue has an unread message from process %d with tag %d\n " ,
16
- world_ .rank (), test_info.test_suite_name (), test_info.name (), msg->source (), msg->tag ());
17
+ com_ .rank (), test_info.test_suite_name (), test_info.name (), msg->source (), msg->tag ());
17
18
exit (2 );
18
19
}
19
- world_ .barrier ();
20
+ com_ .barrier ();
20
21
}
21
22
22
23
private:
23
- boost::mpi::communicator world_;
24
+ boost::mpi::communicator com_;
25
+ };
26
+
27
+ class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
28
+ public:
29
+ WorkerTestFailurePrinter (std::shared_ptr<::testing::TestEventListener> base, boost::mpi::communicator com)
30
+ : base_(std::move(base)), com_(std::move(com)) {}
31
+
32
+ void OnTestEnd (const ::testing::TestInfo& test_info) override {
33
+ if (test_info.result ()->Passed ()) {
34
+ return ;
35
+ }
36
+ PrintProcessRank ();
37
+ base_->OnTestEnd (test_info);
38
+ }
39
+
40
+ void OnTestPartResult (const ::testing::TestPartResult& test_part_result) override {
41
+ if (test_part_result.passed () || test_part_result.skipped ()) {
42
+ return ;
43
+ }
44
+ PrintProcessRank ();
45
+ base_->OnTestPartResult (test_part_result);
46
+ }
47
+
48
+ private:
49
+ void PrintProcessRank () const { printf (" [ PROCESS %d ] " , com_.rank ()); }
50
+
51
+ std::shared_ptr<::testing::TestEventListener> base_;
52
+ boost::mpi::communicator com_;
24
53
};
25
54
26
55
int main (int argc, char ** argv) {
@@ -29,8 +58,9 @@ int main(int argc, char** argv) {
29
58
30
59
::testing::InitGoogleTest (&argc, argv);
31
60
auto & listeners = ::testing::UnitTest::GetInstance ()->listeners ();
32
- if (world.rank () != 0 ) {
33
- delete listeners.Release (listeners.default_result_printer ());
61
+ if (world.rank () != 0 && (argc < 2 || argv[1 ] != std::string (" --print-workers" ))) {
62
+ auto * listener = listeners.Release (listeners.default_result_printer ());
63
+ listeners.Append (new WorkerTestFailurePrinter (std::shared_ptr<::testing::TestEventListener>(listener), world));
34
64
}
35
65
listeners.Append (new UnreadMessagesDetector (world));
36
66
return RUN_ALL_TESTS ();
0 commit comments