Commit 41ad033

Save request outputs and add eval accuracy support (#8)
* Save request outputs and add eval accuracy support
* add readme and requirements
* add space in readme
* add username
* replace Benchmarks with Benchmark
* fix the path
1 parent 0d7fcf7 commit 41ad033

File tree: 4 files changed, +172, -27 lines
benchmarks/README.md

New file, +50 lines:
# JetStream Benchmark and Eval

## Install Dependencies

```
cd ~/JetStream/benchmarks
pip install -r requirements.in
```

## Benchmark

### Prepare Dataset

```
cd ~/data
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

### Run Benchmark with the MaxText Tokenizer

```
python benchmark_serving.py \
  --tokenizer /home/{username}/maxtext/assets/tokenizer \
  --num-prompts 10 \
  --dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json
```

### Save Request Outputs in Benchmark

Use the `--save-request-outputs` flag to enable this feature. Outputs are written to `/tmp/request-outputs.json` by default (configurable with `--request-outputs-file-path`); a sketch of reading the saved file follows below.

```
python benchmark_serving.py \
  --tokenizer /home/{username}/maxtext/assets/tokenizer \
  --num-prompts 10 \
  --dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
  --save-request-outputs
```
## Eval Accuracy

Evaluate the accuracy of the inference-generated output against the saved request outputs (read from /tmp/request-outputs.json by default; override with --output_path).

```
python eval_accuracy.py
```
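As referenced above, here is a minimal sketch of reading the file written by `--save-request-outputs`, assuming the default `--request-outputs-file-path` of `/tmp/request-outputs.json`. The field names follow `RequestFuncOutput.to_dict()` in `benchmark_serving.py`.

```python
# Minimal sketch: inspect the request outputs saved by --save-request-outputs.
# Assumes the default output path /tmp/request-outputs.json from a prior run.
import json

with open("/tmp/request-outputs.json") as f:
  records = json.load(f)

# Each record is the flattened form produced by RequestFuncOutput.to_dict():
# "prompt", "original_output", "generated_text", "success", "latency", "prompt_len".
for record in records[:3]:
  print(record["prompt_len"], record["success"], record["generated_text"][:80])
```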

benchmarks/benchmark_serving.py

+64, -27 lines
```diff
@@ -81,14 +81,33 @@ class BenchmarkMetrics:
   p99_tpot_ms: float
 
 
+@dataclass
+class InputRequest:
+  prompt: str = ""
+  prompt_len: int = 0
+  output: str = ""
+  output_len: int = 0
+
 @dataclass
 class RequestFuncOutput:
+  input_request: InputRequest = None
   generated_text: str = ""
   success: bool = False
   latency: float = 0
   ttft: float = 0
   prompt_len: int = 0
 
+  # Flatten the structure and return only the necessary results
+  def to_dict(self):
+    return {
+        "prompt": self.input_request.prompt,
+        "original_output": self.input_request.output,
+        "generated_text": self.generated_text,
+        "success": self.success,
+        "latency": self.latency,
+        "prompt_len": self.prompt_len
+    }
+
 
 def get_tokenizer(tokenizer_name: str) -> Any:
   """Return a tokenizer or a tokenizer placholder."""
@@ -105,7 +124,7 @@ def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: Any,
-) -> List[Tuple[str, int, int]]:
+) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
     dataset = json.load(f)
@@ -133,11 +152,12 @@ def sample_requests(
   tokenized_dataset = []
   for i in range(len(dataset)):
     output_len = len(completion_token_ids[i])
-    tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    tokenized_dataset.append((prompts[i], prompt_token_ids[i], completions[i], output_len))
 
   # Filter out too long sequences.
-  filtered_dataset: List[Tuple[str, int, int]] = []
-  for prompt, prompt_token_ids, output_len in tokenized_dataset:
+  filtered_dataset: List[InputRequest] = []
+
+  for prompt, prompt_token_ids, output, output_len in tokenized_dataset:
     prompt_len = len(prompt_token_ids)
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
@@ -147,17 +167,18 @@ def sample_requests(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    filtered_dataset.append((prompt, prompt_len, output_len))
+    request = InputRequest(prompt, prompt_len, output, output_len)
+    filtered_dataset.append(request)
 
   # Sample the requests.
   sampled_requests = random.sample(filtered_dataset, num_requests)
   return sampled_requests
 
 
 async def get_request(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     request_rate: float,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[InputRequest, None]:
   input_requests = iter(input_requests)
   for request in input_requests:
     yield request
@@ -172,7 +193,7 @@ async def get_request(
 
 
 def calculate_metrics(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: Any,
@@ -190,7 +211,7 @@ def calculate_metrics(
           else "ĊŌƟ"
       )
       total_output += output_len
-      total_input += input_requests[i][1]
+      total_input += input_requests[i].prompt_len
       per_token_latencies.append(outputs[i].latency / output_len)
       ttfts.append(outputs[i].ttft)
       completed += 1
@@ -234,25 +255,24 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
 
 async def send_request(
     api_url: str,
-    prompt: str,
-    prompt_len: int,
+    input_request: InputRequest,
     pbar: tqdm,
     session_cache: str,
     priority: int,
-    max_tokens: int,
     threads: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
   loop = asyncio.get_running_loop()
   loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
   request = jetstream_pb2.DecodeRequest(
       session_cache=session_cache,
-      additional_text=prompt,
+      additional_text=input_request.prompt,
       priority=priority,
-      max_tokens=max_tokens,
+      max_tokens=input_request.output_len,
   )
   output = RequestFuncOutput()
-  output.prompt_len = prompt_len
+  output.input_request = input_request
+  output.prompt_len = input_request.prompt_len
   generated_text, ttft, latency = await loop.run_in_executor(
       None, grpc_sync_request, api_url, request
   )
@@ -268,7 +288,7 @@ async def send_request(
 async def benchmark(
     api_url: str,
     tokenizer: Any,
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
     session_cache: str,
@@ -283,17 +303,14 @@ async def benchmark(
   benchmark_start_time = time.perf_counter()
   tasks = []
   async for request in get_request(input_requests, request_rate):
-    prompt, prompt_len, output_len = request
     tasks.append(
         asyncio.create_task(
             send_request(
                 api_url=api_url,
-                prompt=prompt,
-                prompt_len=prompt_len,
+                input_request=request,
                 pbar=pbar,
                 session_cache=session_cache,
                 priority=priority,
-                max_tokens=output_len,
                 threads=threads,
             )
         )
@@ -341,17 +358,19 @@ async def benchmark(
       "median_tpot_ms": metrics.median_tpot_ms,
       "p99_tpot_ms": metrics.p99_tpot_ms,
   }
-  return result
+  return result, outputs
 
 
 def mock_requests(total_mock_requests: int):
   """Generates a list of mock requests containing mock data."""
   data = []
   for _ in range(total_mock_requests):
-    name = f"Item {random.randint(1, 1000)}"
-    price = random.randint(10, 100)
-    quantity = random.randint(1, 10)
-    data.append((name, price, quantity))
+    request = InputRequest()
+    request.prompt = f"Prompt {random.randint(1, 1000)}"
+    request.prompt_len = random.randint(10, 100)
+    request.output = f"Output {random.randint(1, 1000)}"
+    request.output_len = random.randint(1, 10)
+    data.append(request)
   return data
 
 
@@ -367,11 +386,11 @@ def main(args: argparse.Namespace):
 
   tokenizer = get_tokenizer(tokenizer_id)
   if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, 3)]
+    input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, "AB", 3)]
   else:
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
-  benchmark_result = asyncio.run(
+  benchmark_result, request_outputs = asyncio.run(
       benchmark(
           api_url=api_url,
           tokenizer=tokenizer,
@@ -411,6 +430,11 @@ def main(args: argparse.Namespace):
     with open(file_name, "w") as outfile:
       json.dump(result_json, outfile)
 
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w") as output_file:
+      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+
 
 if __name__ == "__main__":
   parser = argparse.ArgumentParser(
@@ -506,6 +530,19 @@ def main(args: argparse.Namespace):
           " not implemented, use default empty str)"
       ),
   )
+  parser.add_argument(
+      "--save-request-outputs",
+      action="store_true",
+      help="Specify to store request outputs into a json file",
+  )
+  parser.add_argument(
+      "--request-outputs-file-path",
+      type=str,
+      default="/tmp/request-outputs.json",
+      help=(
+          "File path to store request outputs"
+      ),
+  )
 
   args = parser.parse_args()
   main(args)
```
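To make the new save path concrete, here is a minimal, self-contained sketch of the flattening that `RequestFuncOutput.to_dict()` performs before `main` writes the outputs to `--request-outputs-file-path`. The dataclasses are re-declared locally so the snippet runs without the JetStream gRPC dependencies; the field values are illustrative placeholders.

```python
# Minimal sketch of the to_dict() flattening behind --save-request-outputs.
# The dataclasses are re-declared here so the snippet runs without gRPC deps;
# the real definitions live in benchmark_serving.py. Values are illustrative.
import json
from dataclasses import dataclass


@dataclass
class InputRequest:
  prompt: str = ""
  prompt_len: int = 0
  output: str = ""
  output_len: int = 0


@dataclass
class RequestFuncOutput:
  input_request: InputRequest = None
  generated_text: str = ""
  success: bool = False
  latency: float = 0
  ttft: float = 0
  prompt_len: int = 0

  def to_dict(self):
    # Flatten the nested InputRequest into one record for JSON output.
    return {
        "prompt": self.input_request.prompt,
        "original_output": self.input_request.output,
        "generated_text": self.generated_text,
        "success": self.success,
        "latency": self.latency,
        "prompt_len": self.prompt_len,
    }


request = InputRequest(prompt="2+2=", prompt_len=4, output="4", output_len=1)
out = RequestFuncOutput(
    input_request=request, generated_text="4", success=True, latency=0.12, prompt_len=4
)
print(json.dumps([out.to_dict()], indent=4))
```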

benchmarks/eval_accuracy.py

New file, +55 lines:

```python
import argparse
import nltk
import evaluate
import json

import numpy as np


def postprocess_text(preds, targets):
  preds = [pred.strip() for pred in preds]
  targets = [target.strip() for target in targets]

  # rougeLSum expects newline after each sentence
  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

  return preds, targets


def eval_accuracy(request_outputs_dict):
  metric = evaluate.load("rouge")
  nltk.download('punkt')
  preds = []
  targets = []

  for output in request_outputs_dict:
    preds.append(output["generated_text"])
    targets.append(output["original_output"])
  preds, targets = postprocess_text(preds, targets)
  result = metric.compute(
      predictions=preds, references=targets, use_stemmer=True, use_aggregator=False)
  result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
  prediction_lens = [len(pred) for pred in preds]
  result["gen_len"] = np.sum(prediction_lens)
  result["gen_num"] = len(preds)
  print("\nResults\n")
  print(result)


def main(args):
  with open(args.output_path) as f:
    request_outputs_dict = json.load(f)

  eval_accuracy(request_outputs_dict)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--output_path", type=str,
      default="/tmp/request-outputs.json",
      help="File path which has original_output and inference generated_text.")

  args = parser.parse_args()

  main(args)
```
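For a quick sanity check without running the benchmark, `eval_accuracy` can also be called directly on an in-memory list shaped like the saved records. A minimal sketch, assuming the dependencies from `requirements.in` are installed (the function downloads the NLTK `punkt` tokenizer on first use); the record contents are illustrative placeholders.

```python
# Minimal sketch: call eval_accuracy on hand-made records instead of a saved file.
# Assumes nltk, evaluate, and rouge-score (requirements.in) are installed and that
# this runs from the benchmarks/ directory so eval_accuracy.py is importable.
from eval_accuracy import eval_accuracy

records = [
    {
        "prompt": "What is the capital of France?",
        "original_output": "The capital of France is Paris.",
        "generated_text": "Paris is the capital of France.",
        "success": True,
        "latency": 0.42,
        "prompt_len": 8,
    },
]

# Prints per-corpus ROUGE scores plus gen_len and gen_num.
eval_accuracy(records)
```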

benchmarks/requirements.in

New file, +3 lines:

```
nltk
evaluate
rouge-score
```
