
Commit c6554e1

Merge branch 'main' of github.com:google/JetStream
2 parents: eb9acc1 + b9d4ecf
File tree

2 files changed: +71 −14 lines

benchmarks/benchmark_serving.py (+65 −11)
@@ -91,7 +91,8 @@ class InputRequest:
 @dataclass
 class RequestFuncOutput:
   input_request: InputRequest = None
-  generated_text: str = ""
+  generated_token_list: list[str] = None
+  generated_text: str = None
   success: bool = False
   latency: float = 0
   ttft: float = 0
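
Note on the new field: `generated_token_list` defaults to `None` rather than `[]` because `dataclasses` rejects mutable defaults at class-definition time. A minimal sketch of the usual alternative, assuming one preferred an always-present list (not what this commit does):

from dataclasses import dataclass, field

@dataclass
class RequestFuncOutput:
  # default_factory builds a fresh list per instance; a bare `[]` default
  # would raise ValueError when the class is defined.
  generated_token_list: list[str] = field(default_factory=list)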
@@ -124,6 +125,7 @@ def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: Any,
+    max_output_length: int,
 ) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
@@ -167,7 +169,7 @@ def sample_requests(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    reqeust = InputRequest(prompt, prompt_len, output, output_len)
+    reqeust = InputRequest(prompt, prompt_len, output, max_output_length)
     filtered_dataset.append(reqeust)

   # Sample the requests.
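
The behavioral change here: each sampled request previously carried the tokenized length of its own reference completion, and now carries the shared `max_output_length` cap, so requested output lengths are uniform across the dataset. A toy illustration, with tuples standing in for `InputRequest`:

prompt, prompt_len = "What is JAX?", 5
output = "JAX is a numerical computing library."  # reference completion
output_len = 9            # tokenized length of the reference completion
max_output_length = 1024  # value of the new --max-output-length flag

request_before = (prompt, prompt_len, output, output_len)        # varied per request
request_after = (prompt, prompt_len, output, max_output_length)  # uniform cap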
@@ -206,9 +208,9 @@ def calculate_metrics(
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
-          tokenizer.tokenize(outputs[i].generated_text)
+          outputs[i].generated_token_list
           if tokenizer != "test"
-          else "ĊŌƟ"
+          else ["Ċ", "Ō", "Ɵ"]
       )
       total_output += output_len
       total_input += input_requests[i].prompt_len
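
Counting output tokens from the stored list avoids a second tokenizer pass over the joined text, and re-tokenizing decoded text is not guaranteed to reproduce the streamed token count anyway. A self-contained sketch of that drift (the whitespace tokenizer is a stand-in, not the benchmark's real tokenizer):

class ToyTokenizer:
  def tokenize(self, text):
    return text.split()

tokenizer = ToyTokenizer()
generated_token_list = ["Hel", "lo", " world"]  # as streamed by the server
generated_text = "".join(generated_token_list)  # "Hello world"

old_count = len(tokenizer.tokenize(generated_text))  # 2: re-tokenization drifted
new_count = len(generated_token_list)                # 3: exactly what was streamed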
@@ -234,9 +236,10 @@ def calculate_metrics(
   return metrics


-def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
+def grpc_sync_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
-  with grpc.insecure_channel(api_url) as channel:
+  options = [("grpc.keepalive_timeout_ms", 10000)]
+  with grpc.insecure_channel(api_url, options=options) as channel:
     grpc.channel_ready_future(channel).result()
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
     print("Making request")
@@ -249,8 +252,7 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
       ttft = time.perf_counter() - request_start_time
     token_list.append(token.response[0])
   latency = time.perf_counter() - request_start_time
-  generated_text = "".join(token_list)
-  return generated_text, ttft, latency
+  return token_list, ttft, latency


 async def send_request(
@@ -273,12 +275,13 @@ async def send_request(
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
-  generated_text, ttft, latency = await loop.run_in_executor(
+  generated_token_list, ttft, latency = await loop.run_in_executor(
       None, grpc_sync_request, api_url, request
   )
   output.ttft = ttft
   output.latency = latency
-  output.generated_text = generated_text
+  output.generated_token_list = generated_token_list
+  output.generated_text = "".join(generated_token_list)
   output.success = True
   if pbar:
     pbar.update(1)
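
Context for `loop.run_in_executor`: the gRPC stub here is synchronous, so the call is pushed onto a worker thread to keep the asyncio event loop free; passing `None` selects the default thread pool. A minimal, self-contained sketch of the pattern:

import asyncio
import time

def blocking_call(url):
  time.sleep(0.1)  # stands in for the sync gRPC round trip
  return ["tok1", "tok2"], 0.05, 0.1

async def main():
  loop = asyncio.get_running_loop()
  # None -> default ThreadPoolExecutor; the loop stays responsive while
  # the blocking call runs on a worker thread.
  tokens, ttft, latency = await loop.run_in_executor(
      None, blocking_call, "localhost:9000")
  print(tokens, ttft, latency)

asyncio.run(main())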
@@ -374,6 +377,24 @@ def mock_requests(total_mock_requests: int):
   return data


+def sample_warmup_requests(requests):
+  interesting_buckets = [
+      0,
+      16,
+      32,
+      64,
+      128,
+      256,
+      512,
+      1024,]
+
+  for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
+    for request in requests:
+      if start < request.prompt_len <= end:
+        yield request
+        break
+
+
 def main(args: argparse.Namespace):
   print(args)
   random.seed(args.seed)
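
`sample_warmup_requests` yields at most one request per prompt-length bucket `(start, end]`, so a warmup pass exercises each interesting prefill size once instead of replaying the whole dataset; `main` below then doubles the list (`* 2`) so every selected request runs twice. A self-contained sketch of the selection (`Req` stands in for `InputRequest`):

from dataclasses import dataclass

@dataclass
class Req:
  prompt_len: int

def sample_warmup_requests(requests):
  buckets = [0, 16, 32, 64, 128, 256, 512, 1024]
  for start, end in zip(buckets[:-1], buckets[1:]):
    for request in requests:
      if start < request.prompt_len <= end:
        yield request  # first match wins; move on to the next bucket
        break

reqs = [Req(10), Req(12), Req(100), Req(700)]
print([r.prompt_len for r in sample_warmup_requests(reqs)])
# [10, 100, 700] -- one representative per non-empty bucket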
@@ -388,7 +409,24 @@ def main(args: argparse.Namespace):
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, "AB", 3)]
   else:
-    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
+
+  if args.warmup_first:
+    print('Warm up start:')
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+    benchmark_result, request_outputs = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            session_cache=args.session_cache,
+            priority=args.priority,
+            threads=args.threads,
+        )
+    )
+    print('Warm up done')

   benchmark_result, request_outputs = asyncio.run(
       benchmark(
@@ -501,6 +539,14 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
+
+  parser.add_argument(
+      "--max-output-length",
+      type=int,
+      default=1024,
+      help="The maximum output length for reference request.",
+  )
+
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -543,6 +589,14 @@ def main(args: argparse.Namespace):
           "File path to store request outputs"
       ),
   )
+  parser.add_argument(
+      "--warmup-first",
+      type=bool,
+      default=False,
+      help=(
+          "Whether to send warmup req first"
+      ),
+  )

   args = parser.parse_args()
   main(args)
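
One caveat with `--warmup-first` as written: `argparse` with `type=bool` applies `bool()` to the raw string, and every non-empty string is truthy, so `--warmup-first False` still evaluates to `True`. A common alternative (not what this commit does) is a store-true flag:

import argparse

parser = argparse.ArgumentParser()
# Present -> True, absent -> False; no truthy-string surprise.
parser.add_argument("--warmup-first", action="store_true",
                    help="Whether to send warmup requests first")

print(parser.parse_args([]).warmup_first)                  # False
print(parser.parse_args(["--warmup-first"]).warmup_first)  # True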

jetstream/core/orchestrator.py (+6 −3)
@@ -617,12 +617,15 @@ def Decode(
         'Placed request on the prefill queue.',
     )

-    while True:
+    while not (
+        active_request.complete and active_request.return_channel.empty()
+    ):
       # When an active request is created a queue is instantiated. New tokens
       # are placed there during the decoding loop, we pop from that queue by
       # using the .next method on the active request.
       # Yielding allows for the response to be a streaming grpc call - which
       # can be called via iterating over a for loop on the other side.
+      # The DecodeResponse stream should consume all generated tokens in
+      # return_channel when complete signal is received. It should check if
+      # return_channel is empty to decide if it should exit the while loop.
       yield jetstream_pb2.DecodeResponse(response=active_request.next())
-      if active_request.complete:
-        break
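
The old loop exited as soon as `active_request.complete` was set, which could strand tokens still sitting in the return channel; the new condition keeps yielding until the request is complete and the channel is drained. A self-contained sketch of the drain semantics, with `queue.Queue` standing in for the return channel (an assumption based on the `empty()` call above):

import queue

class ActiveRequest:
  """Toy stand-in for the orchestrator's active request."""
  def __init__(self, tokens):
    self.return_channel = queue.Queue()
    for t in tokens:
      self.return_channel.put(t)
    self.complete = True  # generation finished, but tokens remain queued

  def next(self):
    return self.return_channel.get()

req = ActiveRequest(["a", "b", "c"])
streamed = []
# Exit only once complete AND drained, so no queued token is dropped; the
# old `if complete: break` form would have stopped after a single token.
while not (req.complete and req.return_channel.empty()):
  streamed.append(req.next())
print(streamed)  # ['a', 'b', 'c']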
