Skip to content

Commit 095e846

Browse files
authored
Merge pull request #2574 from bratpiorka/rrudnick_fix_usm_pool_config_parse
fix parseDisjointPoolConfig and add tests
2 parents 0bb6789 + 07001aa commit 095e846

File tree

2 files changed

+90
-35
lines changed

2 files changed

+90
-35
lines changed

source/common/umf_pools/disjoint_pool_config_parser.cpp

+24-35
Original file line numberDiff line numberDiff line change
@@ -174,47 +174,36 @@ DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
174174
MemParser(Params, M);
175175
};
176176

177-
size_t MaxSize = (std::numeric_limits<size_t>::max)();
178-
179177
// Update pool settings if specified in environment.
178+
size_t MaxSize = (std::numeric_limits<size_t>::max)();
180179
size_t EnableBuffers = 1;
181-
if (config != "") {
182-
std::string Params = config;
183-
size_t Pos = Params.find(';');
184-
if (Pos != std::string::npos) {
185-
if (Pos > 0) {
186-
GetValue(Params, Pos, EnableBuffers);
180+
181+
bool EnableBuffersSet = false;
182+
bool MaxSizeSet = false;
183+
size_t Start = 0;
184+
size_t End = config.find(';');
185+
while (true) {
186+
std::string Param = config.substr(Start, End - Start);
187+
if (!EnableBuffersSet && (Param == "" || isdigit(Param[0]))) {
188+
if (Param != "") {
189+
GetValue(Param, Param.size(), EnableBuffers);
187190
}
188-
Params.erase(0, Pos + 1);
189-
size_t Pos = Params.find(';');
190-
if (Pos != std::string::npos) {
191-
if (Pos > 0) {
192-
GetValue(Params, Pos, MaxSize);
193-
}
194-
Params.erase(0, Pos + 1);
195-
do {
196-
size_t Pos = Params.find(';');
197-
if (Pos != std::string::npos) {
198-
if (Pos > 0) {
199-
std::string MemParams = Params.substr(0, Pos);
200-
MemTypeParser(MemParams);
201-
}
202-
Params.erase(0, Pos + 1);
203-
if (Params.size() == 0) {
204-
break;
205-
}
206-
} else {
207-
MemTypeParser(Params);
208-
break;
209-
}
210-
} while (true);
211-
} else {
212-
// set MaxPoolSize for all configs
213-
GetValue(Params, Params.size(), MaxSize);
191+
EnableBuffersSet = true;
192+
} else if (!MaxSizeSet && (Param == "" || isdigit(Param[0]))) {
193+
if (Param != "") {
194+
GetValue(Param, Param.size(), MaxSize);
214195
}
196+
MaxSizeSet = true;
215197
} else {
216-
GetValue(Params, Params.size(), EnableBuffers);
198+
MemTypeParser(Param);
217199
}
200+
201+
if (End == std::string::npos) {
202+
break;
203+
}
204+
205+
Start = End + 1;
206+
End = config.find(';', Start);
218207
}
219208

220209
AllConfigs.EnableBuffers = EnableBuffers;

test/usm/usmPoolManager.cpp

+66
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "umf_pools/disjoint_pool_config_parser.hpp"
78
#include "ur_pool_manager.hpp"
89

910
#include <uur/fixtures.h>
@@ -18,6 +19,26 @@ auto createMockPoolHandle() {
1819
[](umf_memory_pool_t *) {});
1920
}
2021

22+
bool compareConfig(const usm::umf_disjoint_pool_config_t &left,
23+
usm::umf_disjoint_pool_config_t &right) {
24+
return left.MaxPoolableSize == right.MaxPoolableSize &&
25+
left.Capacity == right.Capacity &&
26+
left.SlabMinSize == right.SlabMinSize;
27+
}
28+
29+
bool compareConfigs(const usm::DisjointPoolAllConfigs &left,
30+
usm::DisjointPoolAllConfigs &right) {
31+
return left.EnableBuffers == right.EnableBuffers &&
32+
compareConfig(left.Configs[usm::DisjointPoolMemType::Host],
33+
right.Configs[usm::DisjointPoolMemType::Host]) &&
34+
compareConfig(left.Configs[usm::DisjointPoolMemType::Device],
35+
right.Configs[usm::DisjointPoolMemType::Device]) &&
36+
compareConfig(left.Configs[usm::DisjointPoolMemType::Shared],
37+
right.Configs[usm::DisjointPoolMemType::Shared]) &&
38+
compareConfig(left.Configs[usm::DisjointPoolMemType::SharedReadOnly],
39+
right.Configs[usm::DisjointPoolMemType::SharedReadOnly]);
40+
}
41+
2142
TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
2243
auto &devices = uur::DevicesEnvironment::instance->devices;
2344

@@ -111,4 +132,49 @@ TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
111132
}
112133
}
113134

135+
TEST_P(urUsmPoolManagerTest, config) {
136+
// Check default config
137+
usm::DisjointPoolAllConfigs def;
138+
usm::DisjointPoolAllConfigs parsed1 =
139+
usm::parseDisjointPoolConfig("1;host:2M,4,64K;device:4M,4,64K;"
140+
"shared:0,0,2M;read_only_shared:4M,4,2M",
141+
0);
142+
ASSERT_EQ(compareConfigs(def, parsed1), true);
143+
144+
// Check partially set config
145+
usm::DisjointPoolAllConfigs part1 =
146+
usm::parseDisjointPoolConfig("1;device:4M;shared:0,0,2M", 0);
147+
ASSERT_EQ(compareConfigs(def, part1), true);
148+
149+
// Check partially set config #2
150+
usm::DisjointPoolAllConfigs part2 =
151+
usm::parseDisjointPoolConfig(";device:4M;shared:0,0,2M", 0);
152+
ASSERT_EQ(compareConfigs(def, part2), true);
153+
154+
// Check partially set config #3
155+
usm::DisjointPoolAllConfigs part3 =
156+
usm::parseDisjointPoolConfig(";shared:0,0,2M", 0);
157+
ASSERT_EQ(compareConfigs(def, part3), true);
158+
159+
// Check partially set config #4
160+
usm::DisjointPoolAllConfigs part4 =
161+
usm::parseDisjointPoolConfig(";device:4M", 0);
162+
ASSERT_EQ(compareConfigs(def, part4), true);
163+
164+
// Check partially set config #5
165+
usm::DisjointPoolAllConfigs part5 =
166+
usm::parseDisjointPoolConfig(";;device:4M,4,64K", 0);
167+
ASSERT_EQ(compareConfigs(def, part5), true);
168+
169+
// Check non-default config
170+
usm::DisjointPoolAllConfigs test(def);
171+
test.Configs[usm::DisjointPoolMemType::Shared].MaxPoolableSize = 128 * 1024;
172+
test.Configs[usm::DisjointPoolMemType::Shared].Capacity = 4;
173+
test.Configs[usm::DisjointPoolMemType::Shared].SlabMinSize = 64 * 1024;
174+
175+
usm::DisjointPoolAllConfigs parsed3 =
176+
usm::parseDisjointPoolConfig("1;shared:128K,4,64K", 0);
177+
ASSERT_EQ(compareConfigs(test, parsed3), true);
178+
}
179+
114180
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);

0 commit comments

Comments
 (0)