@@ -237,7 +237,8 @@ def calculate_metrics(
237
237
238
238
def grpc_sync_request (api_url : str , request : Any ) -> tuple [list [str ], float , float ]:
239
239
"""Send grpc synchronous request since the current grpc server is sync."""
240
- with grpc .insecure_channel (api_url ) as channel :
240
+ options = [("grpc.keepalive_timeout_ms" , 10000 )]
241
+ with grpc .insecure_channel (api_url , options = options ) as channel :
241
242
grpc .channel_ready_future (channel ).result ()
242
243
stub = jetstream_pb2_grpc .OrchestratorStub (channel )
243
244
print ("Making request" )
@@ -374,6 +375,24 @@ def mock_requests(total_mock_requests: int):
374
375
return data
375
376
376
377
378
def sample_warmup_requests(requests):
    """Yield at most one warmup request per prompt-length bucket.

    Buckets are the half-open ranges ``(start, end]`` formed by consecutive
    boundaries 0, 16, 32, ..., 1024. For each bucket, the first request (in
    input order) whose ``prompt_len`` falls inside it is selected. Selected
    requests are yielded in ascending bucket order; buckets with no matching
    request are skipped.

    Args:
      requests: Iterable of objects exposing a numeric ``prompt_len``
        attribute. It is traversed at most once, so one-shot iterators and
        generators are handled the same way as lists.

    Yields:
      The first matching request for each non-empty bucket.
    """
    interesting_buckets = [0, 16, 32, 64, 128, 256, 512, 1024]
    bucket_ranges = list(
        zip(interesting_buckets[:-1], interesting_buckets[1:])
    )

    # Single pass over the input: remember the first request seen per bucket.
    # Re-iterating `requests` once per bucket would silently drop matches
    # when the caller passes an iterator instead of a list.
    first_in_bucket = {}
    for request in requests:
        for bounds in bucket_ranges:
            start, end = bounds
            if start < request.prompt_len <= end:
                first_in_bucket.setdefault(bounds, request)
                break
        # Stop early once every bucket has a representative.
        if len(first_in_bucket) == len(bucket_ranges):
            break

    # Emit in bucket order, not in input order.
    for bounds in bucket_ranges:
        if bounds in first_in_bucket:
            yield first_in_bucket[bounds]
394
+
395
+
377
396
def main (args : argparse .Namespace ):
378
397
print (args )
379
398
random .seed (args .seed )
@@ -390,6 +409,23 @@ def main(args: argparse.Namespace):
390
409
else :
391
410
input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer , args .max_output_length )
392
411
412
+ if args .warmup_first :
413
+ print ('Warm up start:' )
414
+ warmup_requests = list (sample_warmup_requests (input_requests )) * 2
415
+ benchmark_result , request_outputs = asyncio .run (
416
+ benchmark (
417
+ api_url = api_url ,
418
+ tokenizer = tokenizer ,
419
+ input_requests = warmup_requests ,
420
+ request_rate = args .request_rate ,
421
+ disable_tqdm = args .disable_tqdm ,
422
+ session_cache = args .session_cache ,
423
+ priority = args .priority ,
424
+ threads = args .threads ,
425
+ )
426
+ )
427
+ print ('Warm up done' )
428
+
393
429
benchmark_result , request_outputs = asyncio .run (
394
430
benchmark (
395
431
api_url = api_url ,
@@ -551,6 +587,14 @@ def main(args: argparse.Namespace):
551
587
"File path to store request outputs"
552
588
),
553
589
)
590
parser.add_argument(
    "--warmup-first",
    # `type=bool` would call bool() on the raw string, so any non-empty
    # value (including "False" or "0") would enable warmup. Parse the text
    # explicitly instead; unrecognized values are treated as False.
    type=lambda value: value.strip().lower()
    in ("1", "true", "t", "yes", "y", "on"),
    # Allow the bare flag (`--warmup-first`) as well as an explicit value
    # (`--warmup-first true` / `--warmup-first false`).
    nargs="?",
    const=True,
    default=False,
    help=(
        "Whether to send warmup requests before running the benchmark."
        " Accepts an optional boolean value (e.g. true/false)."
    ),
)
554
598
555
599
args = parser .parse_args ()
556
600
main (args )
0 commit comments