Skip to content

Commit d0350d2

Browse files
shasson5gleon99
andauthored
Support test variants (#742)
* TEST: Support test variants * TEST: CR fix * TEST: minor fix * TEST/COMMON: CR fix * NIXL/TEST: clang * NIXl/TEST: clang fix --------- Co-authored-by: Leonid Genkin <[email protected]>
1 parent 7e762e0 commit d0350d2

File tree

5 files changed

+56
-20
lines changed

5 files changed

+56
-20
lines changed

src/plugins/ucx/ucx_backend.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,8 +1102,12 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
11021102
err_handling_mode = ucx_err_mode_from_string(err_handling_mode_it->second);
11031103
}
11041104

1105+
const auto engine_config_it = custom_params->find("engine_config");
1106+
const auto engine_config =
1107+
(engine_config_it != custom_params->end()) ? engine_config_it->second : "";
1108+
11051109
uc = std::make_unique<nixlUcxContext>(
1106-
devs, init_params.enableProgTh, num_workers, init_params.syncMode);
1110+
devs, init_params.enableProgTh, num_workers, init_params.syncMode, engine_config);
11071111

11081112
for (size_t i = 0; i < num_workers; i++) {
11091113
uws.emplace_back(std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));

src/utils/ucx/ucx_utils.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept
408408
nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
409409
bool prog_thread,
410410
unsigned long num_workers,
411-
nixl_thread_sync_t sync_mode) {
411+
nixl_thread_sync_t sync_mode,
412+
const std::string &engine_config) {
412413
ucp_params_t ucp_params;
413414
unsigned major_version, minor_version, release_number;
414415
ucp_get_version(&major_version, &minor_version, &release_number);
@@ -454,6 +455,18 @@ nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
454455
config.modify("MAX_COMPONENT_MDS", "32");
455456
}
456457

458+
std::string elem;
459+
std::stringstream stream(engine_config);
460+
461+
while (std::getline(stream, elem, ',')) {
462+
std::string_view elem_view = elem;
463+
size_t pos = elem_view.find('=');
464+
465+
if (pos != std::string::npos) {
466+
config.modify(elem_view.substr(0, pos), elem_view.substr(pos + 1));
467+
}
468+
}
469+
457470
const auto status = ucp_init (&ucp_params, config.getUcpConfig(), &ctx);
458471
if (status != UCS_OK) {
459472
throw std::runtime_error ("Failed to create UCX context: " +

src/utils/ucx/ucx_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ class nixlUcxContext {
201201
nixlUcxContext(std::vector<std::string> devices,
202202
bool prog_thread,
203203
unsigned long num_workers,
204-
nixl_thread_sync_t sync_mode);
204+
nixl_thread_sync_t sync_mode,
205+
const std::string &engine_conf = "");
205206
~nixlUcxContext();
206207

207208
/* Memory management */

test/gtest/common.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <stack>
2626
#include <optional>
2727
#include <mutex>
28+
#include "gtest/gtest.h"
2829

2930
namespace gtest {
3031
constexpr const char *
@@ -102,6 +103,29 @@ class PortAllocator {
102103
uint16_t _max_port = MAX_PORT;
103104
};
104105

106+
struct nixlTestParam {
107+
std::string backendName;
108+
bool progressThreadEnabled;
109+
unsigned numWorkers;
110+
unsigned numThreads;
111+
std::string engineConfig;
112+
};
113+
114+
using nixl_test_t = testing::TestWithParam<nixlTestParam>;
115+
105116
} // namespace gtest
106117

118+
#define NIXL_INSTANTIATE_TEST(_test_name, \
119+
_test_case, \
120+
_backend, \
121+
_progress_thread_enabled, \
122+
_num_workers, \
123+
_num_threads, \
124+
_engine_config) \
125+
INSTANTIATE_TEST_SUITE_P( \
126+
_test_name, \
127+
_test_case, \
128+
testing::ValuesIn(std::vector<nixlTestParam>( \
129+
{{_backend, _progress_thread_enabled, _num_workers, _num_threads, _engine_config}})));
130+
107131
#endif /* TEST_GTEST_COMMON_H */

test/gtest/test_transfer.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ class MemBuffer : std::shared_ptr<void> {
9696
const size_t size;
9797
};
9898

99-
class TestTransfer :
100-
// Tuple fields are: backend_name, enable_progress_thread, num_workers, num_threads
101-
public testing::TestWithParam<std::tuple<std::string, bool, size_t, size_t>> {
99+
class TestTransfer : public nixl_test_t {
102100
protected:
103101
nixlAgentConfig
104102
getConfig(int listen_port, bool capture_telemetry) {
@@ -127,6 +125,7 @@ class TestTransfer :
127125
params["split_batch_size"] = "32";
128126
}
129127

128+
params["engine_config"] = GetParam().engineConfig;
130129
return params;
131130
}
132131

@@ -165,22 +164,22 @@ class TestTransfer :
165164

166165
std::string getBackendName() const
167166
{
168-
return std::get<0>(GetParam());
167+
return GetParam().backendName;
169168
}
170169

171170
bool
172171
isProgressThreadEnabled() const {
173-
return std::get<1>(GetParam());
172+
return GetParam().progressThreadEnabled;
174173
}
175174

176175
size_t
177176
getNumWorkers() const {
178-
return std::get<2>(GetParam());
177+
return GetParam().numWorkers;
179178
}
180179

181180
size_t
182181
getNumThreads() const {
183-
return std::get<3>(GetParam());
182+
return GetParam().numThreads;
184183
}
185184

186185
nixl_opt_args_t
@@ -718,14 +717,9 @@ TEST_P(TestTransfer, PrepGpuSignal) {
718717
#endif
719718
}
720719

721-
INSTANTIATE_TEST_SUITE_P(ucx, TestTransfer, testing::Values(std::make_tuple("UCX", true, 2, 0)));
722-
INSTANTIATE_TEST_SUITE_P(ucx_no_pt,
723-
TestTransfer,
724-
testing::Values(std::make_tuple("UCX", false, 2, 0)));
725-
INSTANTIATE_TEST_SUITE_P(ucx_threadpool,
726-
TestTransfer,
727-
testing::Values(std::make_tuple("UCX", true, 6, 4)));
728-
INSTANTIATE_TEST_SUITE_P(ucx_threadpool_no_pt,
729-
TestTransfer,
730-
testing::Values(std::make_tuple("UCX", false, 6, 4)));
720+
NIXL_INSTANTIATE_TEST(ucx, TestTransfer, "UCX", true, 2, 0, "");
721+
NIXL_INSTANTIATE_TEST(ucx_no_pt, TestTransfer, "UCX", false, 2, 0, "");
722+
NIXL_INSTANTIATE_TEST(ucx_threadpool, TestTransfer, "UCX", true, 6, 4, "");
723+
NIXL_INSTANTIATE_TEST(ucx_threadpool_no_pt, TestTransfer, "UCX", false, 6, 4, "");
724+
731725
} // namespace gtest

0 commit comments

Comments
 (0)