
Commit 2b9db52

Fix output token drop issue (#9)
* Fix output token drop issue
* Add comments and format
* Fix benchmarks
1 parent 8289e65 commit 2b9db52

2 files changed: +14 -12 lines changed


benchmarks/benchmark_serving.py (+8 -9)
@@ -91,7 +91,7 @@ class InputRequest:
 @dataclass
 class RequestFuncOutput:
   input_request: InputRequest = None
-  generated_text: str = ""
+  generated_token_list: list[str] = None
   success: bool = False
   latency: float = 0
   ttft: float = 0
@@ -102,7 +102,7 @@ def to_dict(self):
     return {
         "prompt": self.input_request.prompt,
         "original_output": self.input_request.output,
-        "generated_text": self.generated_text,
+        "generated_token_list": self.generated_token_list,
         "success": self.success,
         "latency": self.latency,
         "prompt_len": self.prompt_len
@@ -207,9 +207,9 @@ def calculate_metrics(
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
-          tokenizer.tokenize(outputs[i].generated_text)
+          outputs[i].generated_token_list
           if tokenizer != "test"
-          else "ĊŌƟ"
+          else ["Ċ", "Ō", "Ɵ"]
       )
       total_output += output_len
       total_input += input_requests[i].prompt_len
@@ -235,7 +235,7 @@ def calculate_metrics(
   return metrics
 
 
-def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
+def grpc_sync_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
   with grpc.insecure_channel(api_url) as channel:
     grpc.channel_ready_future(channel).result()
@@ -250,8 +250,7 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
         ttft = time.perf_counter() - request_start_time
       token_list.append(token.response[0])
     latency = time.perf_counter() - request_start_time
-    generated_text = "".join(token_list)
-    return generated_text, ttft, latency
+    return token_list, ttft, latency
 
 
 async def send_request(
@@ -274,12 +273,12 @@ async def send_request(
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
-  generated_text, ttft, latency = await loop.run_in_executor(
+  generated_token_list, ttft, latency = await loop.run_in_executor(
      None, grpc_sync_request, api_url, request
   )
   output.ttft = ttft
   output.latency = latency
-  output.generated_text = generated_text
+  output.generated_token_list = generated_token_list
   output.success = True
   if pbar:
     pbar.update(1)
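
A minimal sketch, not part of the commit, of why the benchmark now counts len(generated_token_list) instead of re-tokenizing the joined text: a subword tokenizer may segment "".join(token_list) differently from how the server streamed it, so the output-token count (and the throughput metrics derived from it) can drift.

# Hypothetical sketch, not from the commit: counting output tokens from the
# streamed token list versus re-tokenizing the joined text.
streamed_tokens = ["Hel", "lo", " wor", "ld"]  # one entry per DecodeResponse

# Old approach: join, then re-tokenize. A real subword tokenizer may not
# reproduce the original segmentation, so the count can differ from the
# number of tokens the server actually generated.
joined_text = "".join(streamed_tokens)
# old_output_len = len(tokenizer.tokenize(joined_text))  # tokenizer-dependent

# New approach: count the tokens exactly as they were streamed.
new_output_len = len(streamed_tokens)  # == 4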

jetstream/core/orchestrator.py (+6 -3)
@@ -617,12 +617,15 @@ def Decode(
         'Placed request on the prefill queue.',
     )
 
-    while True:
+    while not (
+        active_request.complete and active_request.return_channel.empty()
+    ):
       # When an active request is created a queue is instantiated. New tokens
       # are placed there during the decoding loop, we pop from that queue by
       # using the .next method on the active request.
       # Yielding allows for the response to be a streaming grpc call - which
       # can be called via iterating over a for loop on the other side.
+      # The DecodeResponse stream should consume all generated tokens in
+      # return_channel when complete signal is received. It should check if
+      # return_channel is empty to decide if it should exit the while loop.
       yield jetstream_pb2.DecodeResponse(response=active_request.next())
-      if active_request.complete:
-        break
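
A minimal sketch, not part of the commit, of the loop-exit change, with a plain queue.Queue standing in for active_request.return_channel: breaking as soon as the complete flag is set can leave tokens sitting in the queue, while draining until the queue is empty delivers every generated token.

# Hypothetical sketch, not from the commit: queue.Queue stands in for
# active_request.return_channel.
import queue

return_channel = queue.Queue()
for tok in ["a", "b", "c"]:
  return_channel.put(tok)
complete = True  # generation already finished; three tokens still queued

# Old loop shape: yield one token, see `complete`, break -- "b" and "c"
# would never reach the client.
#   while True:
#     yield return_channel.get()
#     if complete:
#       break

# New loop shape: keep yielding until the request is complete AND the
# channel has been drained, so no generated token is dropped.
delivered = []
while not (complete and return_channel.empty()):
  delivered.append(return_channel.get())

print(delivered)  # ['a', 'b', 'c']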
