Skip to content

Commit e04302c

Browse files
authored
Add option to warm up (#11)
1 parent 2b9db52 commit e04302c

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

benchmarks/benchmark_serving.py

+45-1
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,8 @@ def calculate_metrics(
237237

238238
def grpc_sync_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
239239
"""Send grpc synchronous request since the current grpc server is sync."""
240-
with grpc.insecure_channel(api_url) as channel:
240+
options = [("grpc.keepalive_timeout_ms", 10000)]
241+
with grpc.insecure_channel(api_url, options=options) as channel:
241242
grpc.channel_ready_future(channel).result()
242243
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
243244
print("Making request")
@@ -374,6 +375,24 @@ def mock_requests(total_mock_requests: int):
374375
return data
375376

376377

378+
def sample_warmup_requests(requests):
    """Yield one representative request per prompt-length bucket.

    The buckets are the ranges ``(0, 16], (16, 32], (32, 64], ...,
    (512, 1024]``. For each bucket, in ascending order, the first request in
    ``requests`` whose ``prompt_len`` falls inside the bucket is yielded.
    Buckets with no matching request are skipped, and requests whose
    ``prompt_len`` is ``<= 0`` or ``> 1024`` are never selected.

    Args:
        requests: Iterable of objects exposing a numeric ``prompt_len``
            attribute. One-shot iterables (e.g. generators) are supported.

    Yields:
        At most one request per bucket, ordered by bucket rather than by
        position in the input.
    """
    interesting_buckets = [0, 16, 32, 64, 128, 256, 512, 1024]

    # Every bucket rescans the input from the beginning. Snapshot it first so
    # a one-shot iterator is not partially consumed by earlier buckets, which
    # would make later buckets silently miss matching requests.
    candidates = list(requests)

    for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
        for request in candidates:
            if start < request.prompt_len <= end:
                yield request
                # Only the first match per bucket is needed for warm-up.
                break
394+
395+
377396
def main(args: argparse.Namespace):
378397
print(args)
379398
random.seed(args.seed)
@@ -390,6 +409,23 @@ def main(args: argparse.Namespace):
390409
else:
391410
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
392411

412+
if args.warmup_first:
413+
print('Warm up start:' )
414+
warmup_requests = list(sample_warmup_requests(input_requests)) * 2
415+
benchmark_result, request_outputs = asyncio.run(
416+
benchmark(
417+
api_url=api_url,
418+
tokenizer=tokenizer,
419+
input_requests=warmup_requests,
420+
request_rate=args.request_rate,
421+
disable_tqdm=args.disable_tqdm,
422+
session_cache=args.session_cache,
423+
priority=args.priority,
424+
threads=args.threads,
425+
)
426+
)
427+
print('Warm up done')
428+
393429
benchmark_result, request_outputs = asyncio.run(
394430
benchmark(
395431
api_url=api_url,
@@ -551,6 +587,14 @@ def main(args: argparse.Namespace):
551587
"File path to store request outputs"
552588
),
553589
)
590+
# NOTE: `type=bool` must not be used here: argparse would call bool() on the
# raw string, so every non-empty value (including "False" and "0") would
# enable the warm-up. Parse the text explicitly instead. `nargs="?"` with
# `const=True` additionally lets a bare `--warmup-first` turn the option on,
# while `--warmup-first True` / `--warmup-first False` keep working.
parser.add_argument(
    "--warmup-first",
    type=lambda value: value.strip().lower() in ("1", "true", "t", "yes", "y"),
    nargs="?",
    const=True,
    default=False,
    help=(
        "Whether to send warmup req first"
    ),
)
554598

555599
args = parser.parse_args()
556600
main(args)

0 commit comments

Comments
 (0)