Commit 91fbb46

Support long context dataset accuracy measurement
1 parent b8ad727 commit 91fbb46

File tree: 4 files changed, +217 -16 lines

benchmarks/benchmark_serving.py
+23 -15

@@ -71,6 +71,7 @@
 
 from benchmarks.eval_accuracy import eval_accuracy
 from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
 from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
 from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
   sample_idx: int = -1
+  metric: str = ""
 
 
 @dataclass
@@ -187,10 +189,12 @@ def to_dict(self):
       prompt = self.input_request.prompt
       original_output = self.input_request.output
       sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
     else:
       prompt = None
       original_output = None
       sample_idx = None
+      metric = None
     return {
         "prompt": prompt,
         "original_output": original_output,
@@ -201,6 +205,7 @@ def to_dict(self):
         "ttst_sec": self.ttst_sec,
         "prompt_len": self.prompt_len,
         "sample_idx": sample_idx,
+        "metric": metric,
     }
 
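A quick illustration (not part of the commit) of the per-request dict that to_dict() emits after this change and that the evaluation code consumes; the values below are placeholders, and real entries also carry the other timing fields omitted here.

example_output = {
    "prompt": "<long-context prompt>",          # from the dataset "input" column
    "original_output": "<reference answer>",    # from the dataset "gt_output" column
    "generated_text": "<model response>",       # read by eval_accuracy_longcontext
    "ttst_sec": 1.2,                             # placeholder
    "prompt_len": 8192,                          # placeholder
    "sample_idx": 0,
    "metric": "qa_em",  # newly propagated: one of "rouge", "niah_em", "qa_em"
}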

@@ -281,17 +286,19 @@ def load_openorca_dataset_pkl(
 
 def load_longcontext_dataset_pkl(
     dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
   assert os.path.isfile(dataset_path)
 
   # read pickle file
   data = pandas.read_pickle(dataset_path)
 
   samples = []
+  metrics = []
   for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])
 
-  return samples
+  return samples, metrics
 
 
 def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
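A minimal sketch (not part of the commit) of the pickle layout that load_longcontext_dataset_pkl() above expects: a pandas DataFrame with "input", "gt_output", and "metric" columns, where "metric" names the per-row scorer used later during evaluation. The file path and row contents below are made up.

import pandas as pd

# Hypothetical toy dataset with the three columns the loader reads.
toy = pd.DataFrame(
    {
        "input": ["<long-context prompt 1>", "<long-context prompt 2>"],
        "gt_output": ["Answer: 42", "12345678-1234-1234-1234-123456789abc"],
        "metric": ["qa_em", "niah_em"],
    }
)
toy.to_pickle("/tmp/longcontext_toy.pkl")  # made-up path

# What the loader does with such a file, in essence:
data = pd.read_pickle("/tmp/longcontext_toy.pkl")
samples = [(row["input"], row["gt_output"]) for _, row in data.iterrows()]
metrics = [row["metric"] for _, row in data.iterrows()]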
@@ -417,7 +424,6 @@ def filter_dataset(
     tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
     dataset_type: str,
     max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
@@ -439,7 +445,8 @@ def filter_dataset(
       sample_idx,
   ) in tokenized_dataset:
     if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
     ):
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -474,11 +481,11 @@ def sample_requests(
     dataset_type: str,
     max_output_length: int = 0,
     oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
     max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
 ) -> list[InputRequest]:
 
   # Original dataset size
@@ -514,13 +521,16 @@ def sample_requests(
       tokenized_dataset,
       dataset_type,
       max_output_length,
-      run_mmlu_dataset,
       min_input_length,
       max_input_length,
       max_target_length,
       max_output_multiplier,
   )
 
+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
   # Sample the requests.
   if len(input_requests) > num_requests:
     input_requests = random.sample(input_requests, num_requests)
@@ -1035,11 +1045,6 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
   return parser.parse_args()
 
 
@@ -1058,6 +1063,7 @@ def main(args: argparse.Namespace):
   api_url = f"{args.server}:{args.port}"
 
   tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  metrics = None
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -1080,7 +1086,7 @@
         args.dataset_path,
     )
   elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
         args.dataset_path,
     )
   else:
@@ -1097,11 +1103,11 @@
       num_requests=args.num_prompts,
       dataset_type=args.dataset,
       max_output_length=args.max_output_length,
-      run_mmlu_dataset=args.run_mmlu_dataset,
       min_input_length=args.min_input_length,
       max_input_length=args.max_input_length,
       max_target_length=args.max_target_length,
       max_output_multiplier=args.max_output_multiplier,
+      metrics=metrics,
   )
 
   warmup_requests = None
@@ -1145,8 +1151,10 @@
   # Process output
   output = [output.to_dict() for output in request_outputs]
   if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
       eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
     else:
       eval_json = eval_accuracy(output, args.dataset[:4])
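With the changes above, a long-context run that also evaluates accuracy goes through the new branch: the dataset is selected by name (args.dataset == "longcontext") together with args.run_eval, and the per-request dicts are handed to eval_accuracy_longcontext, which dispatches on each entry's "metric" field. A hypothetical, self-contained sketch of that last step; the toy entries are made up.

from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext

if __name__ == "__main__":  # guard needed: run_evaluation spawns a multiprocessing Pool
  toy_outputs = [
      {"generated_text": "Answer: 42", "original_output": "42|forty two", "metric": "qa_em"},
      {"generated_text": "The cat sat.", "original_output": "The cat sat.", "metric": "rouge"},
  ]
  # Expected to print roughly
  # {"exact_match": 100.0, "rougeL": 100.0, "gen_len": 22, "gen_num": 2}
  print(eval_accuracy_longcontext(toy_outputs))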

benchmarks/eval_accuracy_longcontext.py
+160

@@ -0,0 +1,160 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Evaluate accuracy of JetStream online serving only for long context dataset.
"""

import argparse
import nltk
from tqdm import tqdm
import pandas as pd
import json
import re
from multiprocessing import Pool, cpu_count
import numpy as np
from rouge_score import rouge_scorer


scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)


def rouge(label, pred):
  score = scorer.score(label, pred)
  return {
      "rougeL": 100 * score["rougeL"].fmeasure,
  }


def niah_em(label, pred):
  label_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", label)
  pred_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", pred)

  if len(pred_uuids) == 0:
    return {"exact_match": 0.0}

  # https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py#L28
  score = (
      sum(
          [
              sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
              / len(ref)
              for pred, ref in zip(pred_uuids, label_uuids)
          ]
      )
      / len(pred_uuids)
      * 100
  )

  return {"exact_match": round(score, 2)}


def qa_em(label, pred):
  answer_substring = pred

  if "Answer: " in pred:
    last_answer_index = pred.rfind("Answer: ")
    if last_answer_index == -1:
      return {"exact_match": 0.0}

    answer_substring = pred[last_answer_index + len("Answer: ") :]

  if answer_substring in label:
    return {"exact_match": 100.0}

  normalized_answer = re.sub(r"\s+", "", answer_substring).lower()
  label_entries = [
      re.sub(r"\s+", "", entry).lower() for entry in label.split("|")
  ]

  match_found = any(entry in normalized_answer for entry in label_entries)
  return {"exact_match": 100.0 if match_found else 0.0}


def postprocess_text(preds, targets):
  preds = [pred.strip() for pred in preds]
  targets = [target.strip() for target in targets]

  # rougeLSum expects newline after each sentence
  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

  return preds, targets


def process_item(item):
  pred, target, metric = item
  if metric == 'rouge':
    metric_eval = rouge(target, pred)
  elif metric == 'niah_em':
    metric_eval = niah_em(target, pred)
  elif metric == 'qa_em':
    metric_eval = qa_em(target, pred)
  else:
    raise ValueError(f"Unknown metric: {metric}")
  return metric_eval


def run_evaluation(preds, targets, target_metrics, n_process=None):
  n_process = cpu_count() if n_process is None else n_process
  with Pool(n_process) as pool:
    accuracies = list(
        tqdm(
            pool.imap(process_item, zip(preds, targets, target_metrics)),
            total=len(preds),
        )
    )
  df = pd.DataFrame({"accuracy": accuracies, "metric": target_metrics})
  return df.accuracy.apply(pd.Series).describe().loc["mean"].to_dict()


def eval_accuracy_longcontext(request_outputs_dict):
  nltk.download("punkt")
  preds = []
  targets = []
  target_metrics = []
  for output in request_outputs_dict:
    preds.append(output["generated_text"])
    targets.append(output["original_output"])
    target_metrics.append(output["metric"])
  preds, targets = postprocess_text(preds, targets)
  result = run_evaluation(preds, targets, target_metrics)
  result = dict(result)
  prediction_lens = [len(pred) for pred in preds]
  result["gen_len"] = int(np.sum(prediction_lens))
  result["gen_num"] = len(preds)
  print("\nResults\n")
  print(result)
  return result


def main(args):
  with open(args.output_path, "r", encoding="utf-8") as f:
    request_outputs_dict = json.load(f)

  eval_accuracy_longcontext(request_outputs_dict)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--output_path",
      type=str,
      default="/tmp/request-outputs.json",
      help="File path which has original_output and inference generated_text.",
  )

  parsed_args = parser.parse_args()

  main(parsed_args)
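For reference, a small sketch (not part of the commit) of how the three per-sample scorers defined above behave on made-up strings; the expected outputs follow directly from the function bodies.

from benchmarks.eval_accuracy_longcontext import rouge, niah_em, qa_em

print(rouge("The quick brown fox.", "The quick brown fox."))
# identical strings -> {'rougeL': 100.0}

print(niah_em(
    "12345678-1234-1234-1234-123456789abc",                 # reference contains one UUID
    "the needle was 12345678-1234-1234-1234-123456789abc",  # prediction repeats it
))
# the predicted UUID matches the reference UUID -> {'exact_match': 100.0}

print(qa_em("Paris|paris", "Some reasoning. Answer: Paris"))
# the text after the last "Answer: " occurs in the reference -> {'exact_match': 100.0}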

benchmarks/requirements.in
+1 -1

@@ -1,4 +1,4 @@
-nltk
+nltk==3.8.1
 evaluate
 rouge-score
 transformers
+33

@@ -0,0 +1,33 @@
"""Tests for long context accuracy measurement."""

import unittest

from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext


class TestEvalAccuracy(unittest.TestCase):
  """Tests for long context accuracy measurement."""

  def setUp(self):
    self._request_outputs_dict = [
        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
        {
            "generated_text": "abc",
            "original_output": "abc",
            "metric": "niah_em",
        },
        {
            "generated_text": "abc",
            "original_output": "abc",
            "metric": "niah_em",
        },
    ]

  def test_eval_accuracy_longcontext(self):
    self.assertEqual(
        eval_accuracy_longcontext(self._request_outputs_dict),
        {"rougeL": 100.0, "exact_match": 50.0, "gen_len": 18, "gen_num": 6},
    )
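For context on the expected values asserted above: rouge on identical strings scores 100, qa_em scores 100 because "abc" occurs in the reference, and niah_em scores 0 because no UUID can be extracted from "abc", so the exact_match mean is (100 + 100 + 0 + 0) / 4 = 50.0 while the rougeL mean over the two rouge samples is 100.0. gen_len is six predictions of three characters each (18), and gen_num is the number of predictions (6).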
