@@ -91,7 +91,8 @@ class InputRequest:
 @dataclass
 class RequestFuncOutput:
   input_request: InputRequest = None
-  generated_text: str = ""
+  generated_token_list: list[str] = None
+  generated_text: str = None
   success: bool = False
   latency: float = 0
   ttft: float = 0
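Review note: `generated_token_list: list[str] = None` keeps the dataclass default immutable, at the cost of a `None` check before use. A minimal sketch (not this patch's code) of the `dataclasses.field` alternative, which gives each instance its own empty list; `Any` stands in for `InputRequest` from the surrounding file:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class RequestFuncOutputSketch:
  # `Any` stands in for InputRequest from the benchmark script.
  input_request: Optional[Any] = None
  # default_factory avoids both the shared-mutable-default trap and
  # downstream None checks before len() or "".join().
  generated_token_list: list[str] = field(default_factory=list)
  generated_text: str = ""
  success: bool = False
  latency: float = 0.0
  ttft: float = 0.0
```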
@@ -124,6 +125,7 @@ def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: Any,
+    max_output_length: int,
 ) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
@@ -167,7 +169,7 @@ def sample_requests(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    request = InputRequest(prompt, prompt_len, output, output_len)
+    request = InputRequest(prompt, prompt_len, output, max_output_length)
     filtered_dataset.append(request)
 
   # Sample the requests.
@@ -206,9 +208,9 @@ def calculate_metrics(
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
-          tokenizer.tokenize(outputs[i].generated_text)
+          outputs[i].generated_token_list
           if tokenizer != "test"
-          else "ĊŌƟ"
+          else ["Ċ", "Ō", "Ɵ"]
       )
       total_output += output_len
       total_input += input_requests[i].prompt_len
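Review note: counting `generated_token_list` directly reports the number of tokens the server actually streamed, whereas the old `tokenizer.tokenize(generated_text)` re-tokenizes the decoded string, which is not guaranteed to round-trip to the same count. An illustrative sketch, assuming a HuggingFace-style `tokenize(text) -> list[str]` method:

```python
def count_output_tokens(token_list: list[str], tokenizer) -> tuple[int, int]:
  """Compare the streamed token count with a re-tokenized count."""
  direct = len(token_list)  # what the server actually emitted
  # Re-tokenizing the joined text can merge or split differently
  # (BPE merges, byte fallback, special tokens), so the counts may differ.
  retokenized = len(tokenizer.tokenize("".join(token_list)))
  return direct, retokenized
```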
@@ -234,9 +236,10 @@ def calculate_metrics(
   return metrics
 
 
-def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
+def grpc_sync_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
-  with grpc.insecure_channel(api_url) as channel:
+  options = [("grpc.keepalive_timeout_ms", 10000)]
+  with grpc.insecure_channel(api_url, options=options) as channel:
     grpc.channel_ready_future(channel).result()
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
     print("Making request")
@@ -249,8 +252,7 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
         ttft = time.perf_counter() - request_start_time
       token_list.append(token.response[0])
     latency = time.perf_counter() - request_start_time
-    generated_text = "".join(token_list)
-    return generated_text, ttft, latency
+    return token_list, ttft, latency
 
 
 async def send_request(
@@ -273,12 +275,13 @@ async def send_request(
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
-  generated_text, ttft, latency = await loop.run_in_executor(
+  generated_token_list, ttft, latency = await loop.run_in_executor(
       None, grpc_sync_request, api_url, request
   )
   output.ttft = ttft
   output.latency = latency
-  output.generated_text = generated_text
+  output.generated_token_list = generated_token_list
+  output.generated_text = "".join(generated_token_list)
   output.success = True
   if pbar:
     pbar.update(1)
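Review note: since the gRPC stub is synchronous, each request blocks a thread; `loop.run_in_executor(None, ...)` pushes it onto the default `ThreadPoolExecutor` so the event loop can keep other requests in flight. A self-contained sketch of the pattern, with `blocking_rpc` standing in for `grpc_sync_request`:

```python
import asyncio
import time


def blocking_rpc(api_url: str) -> list[str]:
  time.sleep(0.1)  # stands in for the synchronous streaming Decode call
  return ["Hello", ", ", "world"]


async def send_one(api_url: str) -> str:
  loop = asyncio.get_running_loop()
  # Positional args after the callable are forwarded to it in the worker thread.
  token_list = await loop.run_in_executor(None, blocking_rpc, api_url)
  return "".join(token_list)


print(asyncio.run(send_one("localhost:9000")))  # Hello, world
```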
@@ -374,6 +377,24 @@ def mock_requests(total_mock_requests: int):
   return data
 
 
+def sample_warmup_requests(requests):
+  interesting_buckets = [
+      0,
+      16,
+      32,
+      64,
+      128,
+      256,
+      512,
+      1024,]
+
+  for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
+    for request in requests:
+      if start < request.prompt_len <= end:
+        yield request
+        break
+
+
 def main(args: argparse.Namespace):
   print(args)
   random.seed(args.seed)
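Review note: `sample_warmup_requests` yields at most one request per prompt-length bucket ((0, 16], (16, 32], ..., (512, 1024]), so one warmup pass touches each padded sequence-length shape once; `main` below doubles the list for two passes. A usage sketch with a stand-in request type:

```python
from dataclasses import dataclass


@dataclass
class FakeRequest:  # stand-in for InputRequest
  prompt_len: int


def sample_warmup_requests(requests):  # copied from the patch above
  interesting_buckets = [0, 16, 32, 64, 128, 256, 512, 1024]
  for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
    for request in requests:
      if start < request.prompt_len <= end:
        yield request
        break


requests = [FakeRequest(n) for n in (5, 20, 300, 900)]
warmup = list(sample_warmup_requests(requests)) * 2  # two warmup passes
print([r.prompt_len for r in warmup])  # [5, 20, 300, 900, 5, 20, 300, 900]
```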
@@ -388,7 +409,24 @@ def main(args: argparse.Namespace):
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, "AB", 3)]
   else:
-    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
+
+  if args.warmup_first:
+    print("Warm up start:")
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+    benchmark_result, request_outputs = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            session_cache=args.session_cache,
+            priority=args.priority,
+            threads=args.threads,
+        )
+    )
+    print("Warm up done")
 
   benchmark_result, request_outputs = asyncio.run(
       benchmark(
@@ -501,6 +539,14 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
+
+  parser.add_argument(
+      "--max-output-length",
+      type=int,
+      default=1024,
+      help="The maximum output length for each sampled request.",
+  )
+
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -543,6 +589,14 @@ def main(args: argparse.Namespace):
           "File path to store request outputs"
       ),
   )
+  parser.add_argument(
+      "--warmup-first",
+      type=bool,
+      default=False,
+      help=(
+          "Whether to send warmup requests first"
+      ),
+  )
 
   args = parser.parse_args()
   main(args)
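Review note: argparse's `type=bool` is a known footgun: `bool` applied to any non-empty string is `True`, so `--warmup-first False` would still trigger warmup. Not this patch's code, but the usual safe alternative is a `store_true` flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--warmup-first",
    action="store_true",  # absent -> False, present -> True, no value parsing
    help="Whether to send warmup requests first",
)
print(parser.parse_args(["--warmup-first"]).warmup_first)  # True
print(parser.parse_args([]).warmup_first)                  # False
```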