
Commit bdece3b

Merge pull request #781 from mlcommons/dev
Dev -> main
2 parents 2d1ac6f + 3b832f4 commit bdece3b

File tree: 4 files changed, +137 −3 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ algorithmic_efficiency/workloads/librispeech_conformer/work_dir
 *.vocab
 wandb/
 *.txt
+scoring/plots/
 
 !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
 !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

scoring/compute_speedups.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@ (new file)

"""File to compute speedups (i.e. geometric means between runtimes)."""

import pickle

from absl import app
from absl import flags
import numpy as np
import pandas as pd
from performance_profile import BASE_WORKLOADS
from performance_profile import get_workloads_time_to_target
from scipy import stats

flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
flags.DEFINE_string(
    'base',
    'prize_qualification_baseline',
    'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
)
flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
flags.DEFINE_boolean('self_tuning_ruleset',
                     False,
                     'Whether the self-tuning ruleset is being scored.')
flags.DEFINE_boolean('save_results',
                     False,
                     'Whether to save the results to disk.')
FLAGS = flags.FLAGS

MAX_BUDGETS = {
    'criteo1tb': 7703,
    'fastmri': 8859,
    'imagenet_resnet': 63_008,
    'imagenet_vit': 77_520,
    'librispeech_conformer': 61_068,
    'librispeech_deepspeech': 55_506,
    'ogbg': 18_477,
    'wmt': 48_151,
}


def replace_inf(row):
  """Replace infs with maximum runtime budget (+1 second).

  Args:
    row (pd.Series): The original row.

  Returns:
    pd.Series: The row with infs replaced.
  """
  workload_name = row.name
  # Factor of 3 for self-tuning ruleset
  factor = 3 if FLAGS.self_tuning_ruleset else 1
  max_runtime_workload = factor * MAX_BUDGETS[workload_name]
  row.replace(np.inf, max_runtime_workload + 1, inplace=True)
  return row


def compute_speedup():
  """Compute speedup between two algorithms."""
  # Load results from disk
  with open(FLAGS.results_txt, 'rb') as f:
    results = pickle.load(f)

  # Compute median over runtimes for both training algorithms
  base_results = get_workloads_time_to_target(
      results[FLAGS.base],
      FLAGS.base,
      time_col="score",
      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
  )
  comparison_results = get_workloads_time_to_target(
      results[FLAGS.comparison],
      FLAGS.comparison,
      time_col="score",
      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
  )

  # Merge results
  merged_results = pd.concat([base_results, comparison_results]).transpose()

  # Ignore workload variants (only consider base workloads) for speedup
  merged_results = merged_results.loc[merged_results.index.isin(BASE_WORKLOADS)]

  # Replace infs with maximum runtime budget (+1 second)
  merged_results = merged_results.apply(replace_inf, axis=1)

  # Compute speedup
  merged_results['speedup'] = merged_results[
      f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
  speedups = merged_results['speedup'].to_numpy()
  mean_speedup = stats.gmean(speedups)  # Geometric mean over workload speedups

  print(merged_results, end='\n\n')
  print(
      f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}"
  )

  if FLAGS.save_results:
    # Optionally save results to disk
    print("Saving results to disk...")
    filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv'
    merged_results.to_csv(filename)


def main(_):
  """Main function to compute speedup between two algorithms."""
  compute_speedup()


if __name__ == '__main__':
  flags.mark_flag_as_required('results_txt')
  flags.mark_flag_as_required('comparison')
  app.run(main)
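
To make the new script's core computation concrete, here is a minimal, self-contained sketch of the same idea on made-up runtimes. The submission name my_submission and all numbers below are hypothetical; the real script reads per-workload medians from the pickled scoring results instead.

import numpy as np
import pandas as pd
from scipy import stats

# Runtime budgets (seconds) for the workloads used in this toy example.
BUDGETS = {'fastmri': 8859, 'ogbg': 18_477, 'wmt': 48_151}

# Hypothetical median time-to-target per workload, in seconds.
runtimes = pd.DataFrame({
    'prize_qualification_baseline': [4000.0, 9000.0, np.inf],
    'my_submission': [3000.0, 7500.0, 30000.0],
}, index=['fastmri', 'ogbg', 'wmt'])

# As in replace_inf above, an unreached target (inf) is clipped to the
# workload's runtime budget + 1 second (factor 1, i.e. not self-tuning).
for workload in runtimes.index:
  runtimes.loc[workload] = runtimes.loc[workload].replace(
      np.inf, BUDGETS[workload] + 1)

# Per-workload runtime ratio and its geometric mean across workloads.
speedups = runtimes['my_submission'] / runtimes['prize_qualification_baseline']
print(stats.gmean(speedups.to_numpy()))  # < 1 means faster than the baseline

In practice the script would be invoked along the lines of python3 scoring/compute_speedups.py --results_txt=<path to results pickle> --comparison=<submission name>, since results_txt and comparison are marked as required flags above.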

scoring/performance_profile.py

Lines changed: 23 additions & 2 deletions
@@ -26,6 +26,7 @@
   the dictionary of submissions.
 """
 import itertools
+import json
 import operator
 import os
 import re

@@ -45,6 +46,10 @@
 BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS
 WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
 BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
+# Open json file to read heldout workloads
+# TODO: This probably shouldn't be hardcoded but passed as an argument.
+with open("held_out_workloads_algoperf_v05.json", "r") as f:
+  HELDOUT_WORKLOADS = json.load(f)
 # These global variables have to be set according to the current set of
 # workloads and rules for the scoring to be correct.
 # We do not use the workload registry since it contains test and development

@@ -248,6 +253,9 @@ def filter(x):
     try:
       if x[variant_workload] == np.inf:
         return np.inf
+      # Also check for nan values (e.g. OOMs)
+      elif np.isnan(x[variant_workload]):
+        return np.inf
       else:
         return x[base_workload]
     except KeyError as e:

@@ -306,19 +314,32 @@ def compute_performance_profiles(submissions,
             self_tuning_ruleset,
             strict))
   df = pd.concat(dfs)
+  # Restrict to base and sampled held-out workloads
+  # (ignore the additional workload variants of the baseline
+  # as they cause issues when checking for nans in workload variants).
+  df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
+  # Sort workloads alphabetically (for better display)
+  df = df.reindex(sorted(df.columns), axis=1)
+
+  # For each held-out workload set to inf if the base workload is inf or nan
+  for workload in df.keys():
+    if workload not in BASE_WORKLOADS:
+      # If base do not have finite score set variant score to inf
+      base_workload = get_base_workload_name(workload)
+      df[workload] = df.apply(
+          variant_criteria_filter(workload, base_workload), axis=1)
 
   # Set score to inf if not within 4x of fastest submission
   best_scores = df.min(axis=0)
   df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
 
-  # For each held-out workload if variant target was not hit set submission to inf
+  # For each base workload if variant target was not hit set submission to inf
   for workload in df.keys():
     if workload not in BASE_WORKLOADS:
       # If variants do not have finite score set base_workload score to inf
       base_workload = get_base_workload_name(workload)
       df[base_workload] = df.apply(
           variant_criteria_filter(base_workload, workload), axis=1)
-
   df = df[BASE_WORKLOADS]
 
   if verbosity > 0:
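
The key behavioral change above is that a NaN score on a held-out workload variant (for example from an OOM) is now treated like a missed target, and misses propagate between a variant and its base workload in both directions. Below is a rough, self-contained sketch of that propagation with made-up scores and a simplified stand-in for variant_criteria_filter; the variant name ogbg_gelu is only illustrative here.

import numpy as np
import pandas as pd

# Toy time-to-target scores: rows are submissions, columns are workloads.
scores = pd.DataFrame(
    {'ogbg': [5000.0, 6000.0], 'ogbg_gelu': [np.nan, 7000.0]},
    index=['submission_a', 'submission_b'])

def criteria_filter(target_workload, other_workload):
  # Keep target_workload's score unless other_workload missed its target
  # (inf) or produced no score at all (nan), in which case return inf.
  def filter_fn(x):
    if np.isinf(x[other_workload]) or np.isnan(x[other_workload]):
      return np.inf
    return x[target_workload]
  return filter_fn

# If the held-out variant was not solved, the base workload score becomes inf.
scores['ogbg'] = scores.apply(criteria_filter('ogbg', 'ogbg_gelu'), axis=1)
print(scores['ogbg'])  # submission_a: inf, because its ogbg_gelu score is NaN

Applying the same filter with the roles swapped handles the other direction, which is what the two loops over df.keys() above do before the frame is restricted back to BASE_WORKLOADS.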

scoring/score_submissions.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def main(_):
       results,
       time_col='score',
       min_tau=1.0,
-      max_tau=None,
+      max_tau=4.0,
       reference_submission_tag=None,
       num_points=100,
       scale='linear',
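
For context on the max_tau change: a performance profile value at a given tau is, roughly, the fraction of workloads on which a submission's time-to-target is within tau times the fastest submission's time, so capping tau at 4.0 lines up with the within-4x cutoff applied in performance_profile.py. A small sketch with made-up runtimes (two hypothetical submissions, two workloads):

import numpy as np
import pandas as pd

# Hypothetical time-to-target per workload (seconds); rows are submissions.
runtimes = pd.DataFrame(
    {'fastmri': [4000.0, 5000.0], 'wmt': [30000.0, 20000.0]},
    index=['submission_a', 'submission_b'])

best = runtimes.min(axis=0)          # fastest runtime per workload
taus = np.linspace(1.0, 4.0, 100)    # mirrors min_tau=1.0, max_tau=4.0
profiles = pd.DataFrame(
    {tau: runtimes.le(tau * best, axis=1).mean(axis=1) for tau in taus})

# profiles.loc[s, tau] is the fraction of workloads on which submission s
# is within tau times the fastest runtime.
print(profiles.iloc[:, [0, -1]])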
