Commit 725a1fe: "Small fix"
1 parent 3c7a57d

3 files changed: +25 -47 lines

tuning/libtuner.py (+10 -30)
@@ -43,7 +43,7 @@
 """Do not need to change"""
 
 # Declare special symbols for libtuner to search and locate
-DEVICE_ID_RE = "DEVICE_ID_RE"
+DEVICE_ID_PLACEHOLDER = "!IREE!"
 
 
 @dataclass
@@ -199,31 +199,15 @@ class IREEBenchmarkResult:
     candidate_id: int
     result_str: str
 
-    def extract_key(self, pattern: str) -> Optional[str]:
+    def get_mean_time(self) -> Optional[float]:
         if not self.result_str:
             return None
+        pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}"
         match = re.search(pattern, self.result_str)
         if not match:
             return None
-        return match.group(1)
-
-    def get_mean_time(self) -> Optional[float]:
-        pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}"
-        time_str = self.extract_key(pattern)
-        if not time_str:
-            return None
         try:
-            return float(time_str)
-        except ValueError:
-            return None
-
-    def get_median_time(self) -> Optional[float]:
-        pattern = r"process_time/real_time_median\s+([\d.]+)\s\w{2}"
-        time_str = self.extract_key(pattern)
-        if not time_str:
-            return None
-        try:
-            return float(time_str)
+            return float(match.group(1))
         except ValueError:
             return None
 
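As a quick sanity check on the consolidated get_mean_time(), here is a minimal standalone sketch of the regex it now applies inline, run against a sample iree-benchmark-module line borrowed (and truncated) from the test fixture in tuning/test_libtuner.py below:

    import re

    # Sample iree-benchmark-module output line (dispatch name truncated),
    # taken from the normal_str fixture in tuning/test_libtuner.py.
    line = ".../process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s"
    pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}"

    match = re.search(pattern, line)
    assert match is not None
    assert float(match.group(1)) == 274.0  # captured group is the mean time; the "us" unit is dropped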
@@ -232,7 +216,6 @@ def generate_display_DBR(
     candidate_id: int = 0, mean_time: float = random.uniform(100.0, 500.0)
 ) -> str:
     """Generate dispatch_benchmark_result string for displaying"""
-    # time unit is implicit and dependent on the output of iree-benchmark-module
     return f"{candidate_id}\tMean Time: {mean_time:.1f}\n"
 
 
@@ -243,13 +226,12 @@ def generate_display_MBR(
     calibrated_diff: Optional[float] = None,
 ) -> str:
     """Generate model_benchmark_result string for displaying"""
-    # time unit is implicit and dependent on the output of iree-benchmark-module
     head_str = f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}\n"
-    res_str = f"process_time/real_time_median\t {t1:.3g} ms\n\n"
+    res_str = f"process_time/real_time_mean\t {t1:.3g} ms\n\n"
     if calibrated_diff:
         percentage_change = calibrated_diff * 100
         change_str = f"({percentage_change:+.3f}%)"
-        res_str = f"process_time/real_time_median\t {t1:.3g} ms {change_str}\n\n"
+        res_str = f"process_time/real_time_mean\t {t1:.3g} ms {change_str}\n\n"
     return head_str + res_str
 
 
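With this change the display string reports the mean rather than the median. An illustrative call (values chosen to match the expectations asserted in tuning/test_libtuner.py below):

    # Assuming libtuner is imported as in tuning/test_libtuner.py.
    s = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314)
    # s == "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_mean\t 568 ms (+3.140%)\n\n"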
@@ -534,7 +516,7 @@ def run_command_wrapper(task_tuple: TaskPack) -> TaskResult:
     """pool.imap_unordered can't iterate an iterable of iterables input, this function helps dividing arguments"""
     if task_tuple.command_need_device_id:
         # worker searches for special symbol and substitute to correct device_id
-        pattern = re.compile(re.escape(DEVICE_ID_RE))
+        pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER))
         task_tuple.command = [
             pattern.sub(str(device_id), s) for s in task_tuple.command
         ]
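The rename also matches how the symbol is used here: it is a literal sentinel token, not a regular expression, and re.escape() ensures it is matched verbatim. A minimal standalone sketch of the substitution, with a hypothetical device id of 0:

    import re

    DEVICE_ID_PLACEHOLDER = "!IREE!"
    command = ["./tools/iree-benchmark-module", f"--device={DEVICE_ID_PLACEHOLDER}"]

    # re.escape() makes the token match literally regardless of which characters it contains.
    pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER))
    resolved = [pattern.sub(str(0), s) for s in command]  # "0" stands in for a real device id
    assert resolved == ["./tools/iree-benchmark-module", "--device=0"]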
@@ -907,14 +889,12 @@ def generate_dryrun_model_benchmark_results(
 ) -> tuple[list[TaskResult], list[TaskResult]]:
     candidate_results = []
     for i, j in enumerate(model_candidates):
-        stdout = (
-            f"process_time/real_time_median {random.uniform(100.0, 500.0):.3g} ms"
-        )
+        stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms"
         candidate_results.append(generate_sample_task_result(stdout, j, str(i % 3)))
 
     baseline_results = [
         generate_sample_task_result(
-            f"process_time/real_time_median {random.uniform(100.0, 500.0):.3g} ms",
+            f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms",
             0,
             str(i),
         )
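The dry-run strings stay in the format the parser expects; a small sketch (on the assumption that this output is meant to round-trip through get_mean_time()) confirming the generated line matches the same regex:

    import random
    import re

    stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms"
    match = re.search(r"process_time/real_time_mean\s+([\d.]+)\s\w{2}", stdout)
    assert match is not None and 100.0 <= float(match.group(1)) <= 500.0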
@@ -1142,7 +1122,7 @@ def parse_model_benchmark_results(
             continue
 
         res = IREEBenchmarkResult(candidate_id, result_str)
-        benchmark_time = res.get_median_time()
+        benchmark_time = res.get_mean_time()
 
         # Check completion
         if benchmark_time == None:

tuning/punet_autotune.py (+2 -2)
@@ -51,7 +51,7 @@ def get_dispatch_benchmark_command(
         "timeout",
         "16s",
         "./tools/iree-benchmark-module",
-        f"--device={libtuner.DEVICE_ID_RE}",
+        f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
         f"--module={compiled_vmfb_path.resolve()}",
         "--hip_use_streams=true",
         "--hip_allow_inline_execution=true",
@@ -83,7 +83,7 @@ def get_model_benchmark_command(
         "timeout",
         "180s",
         "tools/iree-benchmark-module",
-        f"--device={libtuner.DEVICE_ID_RE}",
+        f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
         "--hip_use_streams=true",
        "--hip_allow_inline_execution=true",
         "--device_allocator=caching",

tuning/test_libtuner.py (+13 -15)
@@ -83,29 +83,25 @@ def test_IREEBenchmarkResult_get():
     BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 274 us 275 us 3000 items_per_second=3.65481k/s
     BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 273 us 275 us 3000 items_per_second=3.65671k/s
     BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s
-    BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_median 275 us 275 us 3 items_per_second=3.65611k/s
+    BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 275 us 275 us 3 items_per_second=3.65611k/s
     BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_stddev 0.073 us 0.179 us 3 items_per_second=0.971769/s
     BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_cv 0.03 % 0.07 % 3 items_per_second=0.03%
     """
     res = libtuner.IREEBenchmarkResult(candidate_id=1, result_str=normal_str)
     assert res.get_mean_time() == float(274)
-    assert res.get_median_time() == float(275)
 
     # Time is float
     res = libtuner.IREEBenchmarkResult(
         candidate_id=2,
-        result_str="process_time/real_time_mean 123.45 us, process_time/real_time_median 246.78 us",
+        result_str="process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us",
     )
     assert res.get_mean_time() == 123.45
-    assert res.get_median_time() == 246.78
 
     # Invalid str
     res = libtuner.IREEBenchmarkResult(candidate_id=3, result_str="hello world")
     assert res.get_mean_time() == None
-    assert res.get_median_time() == None
     res = libtuner.IREEBenchmarkResult(candidate_id=4, result_str="")
     assert res.get_mean_time() == None
-    assert res.get_median_time() == None
 
 
 def test_generate_display_BR():
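One subtlety the updated fixture relies on: re.search() returns only the first occurrence, so with two real_time_mean entries in the string, get_mean_time() reports the first one (123.45, not 246.78). A minimal sketch:

    import re

    result_str = "process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us"
    match = re.search(r"process_time/real_time_mean\s+([\d.]+)\s\w{2}", result_str)
    assert float(match.group(1)) == 123.45  # first occurrence wins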
@@ -114,13 +110,13 @@ def test_generate_display_BR():
     assert output == expected, "DispatchBenchmarkResult generates invalid sample string"
 
     output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89)
-    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_median\t 568 ms\n\n"
+    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_mean\t 568 ms\n\n"
     assert output == expected, "ModelBenchmarkResult generates invalid sample string"
     output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314)
-    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_median\t 568 ms (+3.140%)\n\n"
+    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_mean\t 568 ms (+3.140%)\n\n"
     assert output == expected, "ModelBenchmarkResult generates invalid sample string"
     output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, -3.14)
-    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_median\t 568 ms (-314.000%)\n\n"
+    expected = "Benchmarking: baseline.vmfb on device 1\nprocess_time/real_time_mean\t 568 ms (-314.000%)\n\n"
     assert output == expected, "ModelBenchmarkResult generates invalid sample string"
 
 
@@ -225,13 +221,11 @@ def test_parse_model_benchmark_results():
     baseline_results = [result3, result4]
 
     # Mock IREEBenchmarkResult to return float value from stdout
-    def mock_get_median_time(self):
+    def mock_get_mean_time(self):
         return float(self.result_str)
 
     # Mock IREEBenchmarkResult to return specific benchmark times
-    with patch(
-        "libtuner.IREEBenchmarkResult.get_median_time", new=mock_get_median_time
-    ):
+    with patch("libtuner.IREEBenchmarkResult.get_mean_time", new=mock_get_mean_time):
         # Mock generate_display_MBR to return a fixed display string
         with patch(
             "libtuner.generate_display_MBR",
@@ -248,12 +242,16 @@ def mock_get_median_time(self):
             assert tracker1.model_benchmark_time == 1.23
             assert tracker1.model_benchmark_device_id == "device1"
             assert tracker1.baseline_benchmark_time == 0.98
-            assert tracker1.calibrated_benchmark_diff == (1.23 - 0.98) / 0.98
+            assert tracker1.calibrated_benchmark_diff == pytest.approx(
+                (1.23 - 0.98) / 0.98, rel=1e-6
+            )
 
             assert tracker2.model_benchmark_time == 4.56
             assert tracker2.model_benchmark_device_id == "device2"
             assert tracker2.baseline_benchmark_time == 4.13
-            assert tracker2.calibrated_benchmark_diff == (4.56 - 4.13) / 4.13
+            assert tracker2.calibrated_benchmark_diff == pytest.approx(
+                (4.56 - 4.13) / 4.13, rel=1e-6
+            )
 
             assert result == [
                 "display_str",
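The switch to pytest.approx avoids brittle exact equality on derived floats such as (1.23 - 0.98) / 0.98, whose intermediate results are not exactly representable in binary floating point. A standalone sketch of the idea (assumes pytest is installed):

    import pytest

    # Exact equality on computed floats can fail due to rounding; approx()
    # compares within a relative tolerance instead.
    assert (1.23 - 0.98) / 0.98 == pytest.approx(0.25510204, rel=1e-6)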