
Commit 07a443b

Support long context dataset accuracy measurement
1 parent b8ad727 commit 07a443b

3 files changed: +213 -15 lines

benchmarks/benchmark_serving.py (+22 -15)
@@ -71,6 +71,7 @@
 
 from benchmarks.eval_accuracy import eval_accuracy
 from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
 from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
 from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
   sample_idx: int = -1
+  metric: str = ""
 
 
 @dataclass
@@ -187,6 +189,7 @@ def to_dict(self):
       prompt = self.input_request.prompt
       original_output = self.input_request.output
       sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
     else:
       prompt = None
       original_output = None
@@ -201,6 +204,7 @@ def to_dict(self):
         "ttst_sec": self.ttst_sec,
         "prompt_len": self.prompt_len,
         "sample_idx": sample_idx,
+        "metric": metric,
     }
 
 
@@ -281,17 +285,19 @@ def load_openorca_dataset_pkl(
 
 def load_longcontext_dataset_pkl(
     dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
   assert os.path.isfile(dataset_path)
 
   # read pickle file
   data = pandas.read_pickle(dataset_path)
 
   samples = []
+  metrics = []
   for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])
 
-  return samples
+  return samples, metrics
 
 
 def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
@@ -417,7 +423,6 @@ def filter_dataset(
     tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
     dataset_type: str,
     max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
@@ -439,7 +444,8 @@ def filter_dataset(
       sample_idx,
   ) in tokenized_dataset:
     if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
     ):
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -474,11 +480,11 @@ def sample_requests(
     dataset_type: str,
     max_output_length: int = 0,
     oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
     max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
 ) -> list[InputRequest]:
 
   # Original dataset size
@@ -514,13 +520,16 @@ def sample_requests(
       tokenized_dataset,
       dataset_type,
       max_output_length,
-      run_mmlu_dataset,
       min_input_length,
       max_input_length,
       max_target_length,
       max_output_multiplier,
   )
 
+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
   # Sample the requests.
   if len(input_requests) > num_requests:
     input_requests = random.sample(input_requests, num_requests)
@@ -1035,11 +1044,6 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
   return parser.parse_args()
 
 
@@ -1058,6 +1062,7 @@ def main(args: argparse.Namespace):
   api_url = f"{args.server}:{args.port}"
 
   tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  metrics = None
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -1080,7 +1085,7 @@ def main(args: argparse.Namespace):
         args.dataset_path,
     )
   elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
        args.dataset_path,
     )
   else:
@@ -1097,11 +1102,11 @@ def main(args: argparse.Namespace):
       num_requests=args.num_prompts,
       dataset_type=args.dataset,
       max_output_length=args.max_output_length,
-      run_mmlu_dataset=args.run_mmlu_dataset,
       min_input_length=args.min_input_length,
       max_input_length=args.max_input_length,
       max_target_length=args.max_target_length,
       max_output_multiplier=args.max_output_multiplier,
+      metrics=metrics,
   )
 
   warmup_requests = None
@@ -1145,8 +1150,10 @@ def main(args: argparse.Namespace):
   # Process output
   output = [output.to_dict() for output in request_outputs]
   if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
       eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
     else:
       eval_json = eval_accuracy(output, args.dataset[:4])
 
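Note on the dataset format: load_longcontext_dataset_pkl() above expects a pickled pandas DataFrame whose rows carry an "input" prompt, a "gt_output" reference, and a "metric" column naming the scorer for that row (rouge, niah_em, or qa_em). A minimal sketch of producing such a file; the example rows and output path are illustrative assumptions, not part of this commit:

# Sketch only: a tiny pickle in the layout read by load_longcontext_dataset_pkl().
# The rows and the file path are made-up examples.
import pandas as pd

df = pd.DataFrame({
    "input": ["<long context prompt 1>", "<long context prompt 2>"],
    "gt_output": ["a reference summary", "answer-a|answer-b"],
    "metric": ["rouge", "qa_em"],  # per-row scorer: rouge, niah_em, or qa_em
})
df.to_pickle("/tmp/longcontext-sample.pkl")  # point the benchmark's dataset path here with --dataset longcontext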
benchmarks/eval_accuracy_longcontext.py (new file, +156)

@@ -0,0 +1,156 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluate accuracy of JetStream online serving only for long context dataset."""
+
+import argparse
+import nltk
+import evaluate
+from tqdm import tqdm
+import pandas as pd
+import json
+import re
+from multiprocessing import Pool, cpu_count
+import numpy as np
+from rouge_score import rouge_scorer
+
+
+scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+
+
+def rouge(label, pred):
+  score = scorer.score(label, pred)
+  return {
+      "rougeL": 100 * score["rougeL"].fmeasure,
+  }
+
+
+def niah_em(label, pred):
+  label_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", label)
+  pred_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", pred)
+
+  if len(pred_uuids) == 0:
+    return {"exact_match": 0.0}
+
+  # https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py#L28
+  score = (
+      sum(
+          [
+              sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
+              / len(ref)
+              for pred, ref in zip(pred_uuids, label_uuids)
+          ]
+      )
+      / len(pred_uuids)
+      * 100
+  )
+
+  return {"exact_match": round(score, 2)}
+
+
+def qa_em(label, pred):
+  answer_substring = pred
+
+  if "Answer: " in pred:
+    last_answer_index = pred.rfind("Answer: ")
+    if last_answer_index == -1:
+      return {"exact_match": 0.0}
+
+    answer_substring = pred[last_answer_index + len("Answer: ") :]
+
+  if answer_substring in label:
+    return {"exact_match": 100.0}
+
+  normalized_answer = re.sub(r"\s+", "", answer_substring).lower()
+  label_entries = [
+      re.sub(r"\s+", "", entry).lower() for entry in label.split("|")
+  ]
+
+  match_found = any(entry in normalized_answer for entry in label_entries)
+  return {"exact_match": 100.0 if match_found else 0.0}
+
+
+metrics = {fn.__name__: fn for fn in [rouge, niah_em, qa_em]}
+
+
+def postprocess_text(preds, targets):
+  preds = [pred.strip() for pred in preds]
+  targets = [target.strip() for target in targets]
+
+  # rougeLSum expects newline after each sentence
+  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
+
+  return preds, targets
+
+
+def process_item(item):
+  pred, target, metric = item
+  metric_fn = metrics[metric]
+  metric_eval = metric_fn(target, pred)
+  return metric_eval
+
+
+def run_evaluation(preds, targets, metrics, n_process=None):
+  n_process = cpu_count() if n_process is None else n_process
+  with Pool(n_process) as pool:
+    accuracies = list(
+        tqdm(
+            pool.imap(process_item, zip(preds, targets, metrics)),
+            total=len(preds),
+        )
+    )
+  df = pd.DataFrame({"accuracy": accuracies, "metric": metrics})
+  return df.accuracy.apply(pd.Series).describe().loc["mean"].to_dict()
+
+
+def eval_accuracy_longcontext(request_outputs_dict):
+  nltk.download("punkt")
+  preds = []
+  targets = []
+  metrics = []
+  for output in request_outputs_dict:
+    preds.append(output["generated_text"])
+    targets.append(output["original_output"])
+    metrics.append(output["metric"])
+  preds, targets = postprocess_text(preds, targets)
+  result = run_evaluation(preds, targets, metrics)
+  result = dict(result)
+  prediction_lens = [len(pred) for pred in preds]
+  result["gen_len"] = int(np.sum(prediction_lens))
+  result["gen_num"] = len(preds)
+  print("\nResults\n")
+  print(result)
+  return result
+
+
+def main(args):
+  with open(args.output_path, "r", encoding="utf-8") as f:
+    request_outputs_dict = json.load(f)
+
+  eval_accuracy_longcontext(request_outputs_dict)
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--output_path",
+      type=str,
+      default="/tmp/request-outputs.json",
+      help="File path which has original_output and inference generated_text.",
+  )
+
+  parsed_args = parser.parse_args()
+
+  main(parsed_args)
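With the benchmark's per-request outputs saved to JSON (each record carrying generated_text, original_output, and metric, as emitted by to_dict() above), the evaluation can also be re-run after the fact. A small sketch, assuming the default output path used by this script:

# Sketch: re-running the long-context evaluation on saved request outputs.
# /tmp/request-outputs.json matches this script's default --output_path; adjust as needed.
import json

from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext

with open("/tmp/request-outputs.json", "r", encoding="utf-8") as f:
  request_outputs = json.load(f)

# Returns a dict like {"rougeL": ..., "exact_match": ..., "gen_len": ..., "gen_num": ...}.
results = eval_accuracy_longcontext(request_outputs)

The same evaluation is available from the command line via python -m benchmarks.eval_accuracy_longcontext --output_path <file>.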
New test file (+35)

@@ -0,0 +1,35 @@
+"""Tests for long context accuracy measurement."""
+
+import unittest
+
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
+import datetime
+import re
+
+
+class TestEvalAccuracy(unittest.TestCase):
+  """Tests for long context accuracy measurement."""
+
+  def setUp(self):
+    self._request_outputs_dict = [
+        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
+        {
+            "generated_text": "abc",
+            "original_output": "abc",
+            "metric": "niah_em",
+        },
+        {
+            "generated_text": "abc",
+            "original_output": "abc",
+            "metric": "niah_em",
+        },
+    ]
+
+  def test_eval_accuracy_longcontext(self):
+    self.assertEqual(
+        eval_accuracy_longcontext(self._request_outputs_dict),
+        {"rougeL": 100.0, "exact_match": 50.0, "gen_len": 18, "gen_num": 6},
+    )
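The expected values follow from how run_evaluation() takes a column-wise mean over only the rows that define each score: both rouge rows score 100.0, both qa_em rows score 100.0 (the prediction "abc" appears verbatim in the label), and both niah_em rows score 0.0 because the prediction contains no UUID, so exact_match = (100 + 100 + 0 + 0) / 4 = 50.0 while rougeL stays at 100.0; gen_len is the total character length of the predictions (six predictions of length 3). A quick sketch of those per-row scores, calling the metric functions directly:

# Per-row scores behind the aggregate expected by the test above.
from benchmarks.eval_accuracy_longcontext import rouge, qa_em, niah_em

print(rouge("abc", "abc"))    # {'rougeL': 100.0}
print(qa_em("abc", "abc"))    # {'exact_match': 100.0}  ("abc" is contained in the label)
print(niah_em("abc", "abc"))  # {'exact_match': 0.0}    (no UUIDs found in the prediction)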
