@@ -81,14 +81,33 @@ class BenchmarkMetrics:
   p99_tpot_ms: float
 
 
+@dataclass
+class InputRequest:
+  prompt: str = ""
+  prompt_len: int = 0
+  output: str = ""
+  output_len: int = 0
+
 @dataclass
 class RequestFuncOutput:
+  input_request: InputRequest = None
   generated_text: str = ""
   success: bool = False
   latency: float = 0
   ttft: float = 0
   prompt_len: int = 0
 
+  # Flatten the structure and return only the necessary results.
+  def to_dict(self):
+    return {
+        "prompt": self.input_request.prompt,
+        "original_output": self.input_request.output,
+        "generated_text": self.generated_text,
+        "success": self.success,
+        "latency": self.latency,
+        "prompt_len": self.prompt_len,
+    }
+
 
 def get_tokenizer(tokenizer_name: str) -> Any:
   """Return a tokenizer or a tokenizer placeholder."""
@@ -105,7 +124,7 @@ def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: Any,
-) -> List[Tuple[str, int, int]]:
+) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
     dataset = json.load(f)
@@ -133,11 +152,12 @@ def sample_requests(
   tokenized_dataset = []
   for i in range(len(dataset)):
     output_len = len(completion_token_ids[i])
-    tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    tokenized_dataset.append((prompts[i], prompt_token_ids[i], completions[i], output_len))
 
   # Filter out too long sequences.
-  filtered_dataset: List[Tuple[str, int, int]] = []
-  for prompt, prompt_token_ids, output_len in tokenized_dataset:
+  filtered_dataset: List[InputRequest] = []
+
+  for prompt, prompt_token_ids, output, output_len in tokenized_dataset:
     prompt_len = len(prompt_token_ids)
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
@@ -147,17 +167,18 @@ def sample_requests(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    filtered_dataset.append((prompt, prompt_len, output_len))
+    request = InputRequest(prompt, prompt_len, output, output_len)
+    filtered_dataset.append(request)
 
   # Sample the requests.
   sampled_requests = random.sample(filtered_dataset, num_requests)
   return sampled_requests
 
 
 async def get_request(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     request_rate: float,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[InputRequest, None]:
   input_requests = iter(input_requests)
   for request in input_requests:
     yield request
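Review note: this hunk leaves the pacing logic of `get_request` untouched, so it is not shown. For readers, the usual open-loop pattern looks roughly like the sketch below; the exponential inter-arrival draw is an assumption about the surrounding code, not a quote of this file:

```python
import asyncio
import random
from typing import AsyncGenerator, List

async def paced_requests(
    input_requests: List["InputRequest"],
    request_rate: float,
) -> AsyncGenerator["InputRequest", None]:
  # Hypothetical pacing helper: yield each request, then sleep for an
  # exponentially distributed interval so arrivals approximate a
  # Poisson process at the given rate (requests per second).
  for request in input_requests:
    yield request
    if request_rate == float("inf"):
      continue  # no pacing: issue requests back to back
    await asyncio.sleep(random.expovariate(request_rate))
```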
@@ -172,7 +193,7 @@ async def get_request(
 
 
 def calculate_metrics(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: Any,
@@ -190,7 +211,7 @@ def calculate_metrics(
         else "ĊŌƟ"
     )
     total_output += output_len
-    total_input += input_requests[i][1]
+    total_input += input_requests[i].prompt_len
     per_token_latencies.append(outputs[i].latency / output_len)
     ttfts.append(outputs[i].ttft)
     completed += 1
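Review note: the switch from index-based access to `input_requests[i].prompt_len` is the main readability win here. A reduced sketch of the accumulation this loop performs, with a placeholder token count standing in for the real tokenizer call:

```python
# Hypothetical reduced form of the metrics accumulation: walk parallel
# lists of InputRequest and RequestFuncOutput and total up token counts.
total_input = total_output = completed = 0
per_token_latencies, ttfts = [], []
for request, output in zip(input_requests, outputs):
  if not output.success:
    continue
  # Placeholder: the real code derives this from tokenizing generated_text.
  output_len = max(1, len(output.generated_text.split()))
  total_output += output_len
  total_input += request.prompt_len
  per_token_latencies.append(output.latency / output_len)
  ttfts.append(output.ttft)
  completed += 1
```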
@@ -234,25 +255,24 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
 
 async def send_request(
     api_url: str,
-    prompt: str,
-    prompt_len: int,
+    input_request: InputRequest,
     pbar: tqdm,
     session_cache: str,
     priority: int,
-    max_tokens: int,
     threads: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
   loop = asyncio.get_running_loop()
   loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
   request = jetstream_pb2.DecodeRequest(
       session_cache=session_cache,
-      additional_text=prompt,
+      additional_text=input_request.prompt,
       priority=priority,
-      max_tokens=max_tokens,
+      max_tokens=input_request.output_len,
   )
   output = RequestFuncOutput()
-  output.prompt_len = prompt_len
+  output.input_request = input_request
+  output.prompt_len = input_request.prompt_len
   generated_text, ttft, latency = await loop.run_in_executor(
       None, grpc_sync_request, api_url, request
   )
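Review note: the executor pattern used here is what lets many blocking gRPC calls run concurrently under asyncio. The same pattern in isolation, with a sleep standing in for the gRPC round trip:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

def blocking_call(payload: str) -> str:
  # Stand-in for a synchronous gRPC round trip like grpc_sync_request.
  time.sleep(0.1)
  return payload.upper()

async def main() -> None:
  loop = asyncio.get_running_loop()
  loop.set_default_executor(ThreadPoolExecutor(max_workers=4))
  # Passing None as the executor uses the default set above; the event
  # loop stays free while worker threads block on the calls.
  results = await asyncio.gather(
      *(loop.run_in_executor(None, blocking_call, f"req-{i}") for i in range(4))
  )
  print(results)  # ['REQ-0', 'REQ-1', 'REQ-2', 'REQ-3']

asyncio.run(main())
```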
@@ -268,7 +288,7 @@ async def send_request(
 async def benchmark(
     api_url: str,
     tokenizer: Any,
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
     session_cache: str,
@@ -283,17 +303,14 @@ async def benchmark(
   benchmark_start_time = time.perf_counter()
   tasks = []
   async for request in get_request(input_requests, request_rate):
-    prompt, prompt_len, output_len = request
     tasks.append(
         asyncio.create_task(
             send_request(
                 api_url=api_url,
-                prompt=prompt,
-                prompt_len=prompt_len,
+                input_request=request,
                 pbar=pbar,
                 session_cache=session_cache,
                 priority=priority,
-                max_tokens=output_len,
                 threads=threads,
             )
         )
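Review note: this hunk ends before the tasks are awaited; presumably the function later gathers them into the `outputs` that the new `return result, outputs` (next hunk) hands back. An assumption about the code between these hunks:

```python
# Collect one RequestFuncOutput per task once all in-flight requests finish.
outputs = await asyncio.gather(*tasks)
```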
@@ -341,17 +358,19 @@ async def benchmark(
       "median_tpot_ms": metrics.median_tpot_ms,
       "p99_tpot_ms": metrics.p99_tpot_ms,
   }
-  return result
+  return result, outputs
 
 
 def mock_requests(total_mock_requests: int):
   """Generates a list of mock requests containing mock data."""
   data = []
   for _ in range(total_mock_requests):
-    name = f"Item {random.randint(1, 1000)}"
-    price = random.randint(10, 100)
-    quantity = random.randint(1, 10)
-    data.append((name, price, quantity))
+    request = InputRequest()
+    request.prompt = f"Prompt {random.randint(1, 1000)}"
+    request.prompt_len = random.randint(10, 100)
+    request.output = f"Output {random.randint(1, 1000)}"
+    request.output_len = random.randint(1, 10)
+    data.append(request)
   return data
 
 
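Review note: since `InputRequest` is a dataclass, the mock builder could also construct each record in a single call. A sketch of the equivalent keyword form, assuming the definitions above and the module's existing `random` import:

```python
from typing import List

def mock_requests(total_mock_requests: int) -> List["InputRequest"]:
  # Equivalent one-expression construction via the dataclass initializer.
  return [
      InputRequest(
          prompt=f"Prompt {random.randint(1, 1000)}",
          prompt_len=random.randint(10, 100),
          output=f"Output {random.randint(1, 1000)}",
          output_len=random.randint(1, 10),
      )
      for _ in range(total_mock_requests)
  ]
```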
@@ -367,11 +386,11 @@ def main(args: argparse.Namespace):
 
   tokenizer = get_tokenizer(tokenizer_id)
   if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, 3)]
+    input_requests = mock_requests(args.total_mock_requests)  # e.g. [InputRequest("AB", 2, "AB", 3), ...]
   else:
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
-  benchmark_result = asyncio.run(
+  benchmark_result, request_outputs = asyncio.run(
       benchmark(
           api_url=api_url,
           tokenizer=tokenizer,
@@ -411,6 +430,11 @@ def main(args: argparse.Namespace):
   with open(file_name, "w") as outfile:
     json.dump(result_json, outfile)
 
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w") as output_file:
+      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+
 
 
 if __name__ == "__main__":
   parser = argparse.ArgumentParser(
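Review note: with `--save-request-outputs` set, the dump above writes one flat object per request, in the shape produced by `to_dict()`. An illustrative record (all values made up):

```python
# One entry of /tmp/request-outputs.json, shown as the Python dict
# that json.dump serializes.
record = {
    "prompt": "2+2=",
    "original_output": "4",
    "generated_text": "4",
    "success": True,
    "latency": 0.42,
    "prompt_len": 4,
}
```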
@@ -506,6 +530,19 @@ def main(args: argparse.Namespace):
           " not implemented, use default empty str)"
       ),
   )
+  parser.add_argument(
+      "--save-request-outputs",
+      action="store_true",
+      help="Specify to store request outputs into a json file",
+  )
+  parser.add_argument(
+      "--request-outputs-file-path",
+      type=str,
+      default="/tmp/request-outputs.json",
+      help=(
+          "File path to store request outputs"
+      ),
+  )
 
   args = parser.parse_args()
   main(args)
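Review note: both new flags are opt-in; a typical run would append `--save-request-outputs --request-outputs-file-path /tmp/request-outputs.json` to the existing benchmark invocation (the path shown is just the default from the flag definition above).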