
Commit 131aae9: add prompt logprobs
Parent: fa09838

9 files changed (+759, -9 lines)

fastdeploy/config.py

Lines changed: 2 additions & 2 deletions

@@ -227,8 +227,8 @@ def __init__(
         self.think_end_id = args.get("think_end_id", -1)
         self.im_patch_id = args.get("image_patch_id", -1)
         self.line_break_id = args.get("line_break_id", -1)
-        if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
-            self.max_logprobs = self.vocab_size
+        if self.max_logprobs < -1 or self.max_logprobs > self.ori_vocab_size:
+            raise ValueError("max_logprobs must be -1 or in [0, vocab_size].")

         self._post_init()

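For context, the new check treats max_logprobs as either the sentinel -1 (logprobs over the full vocabulary) or a count bounded by the vocabulary size. A minimal sketch of that rule as a standalone helper (the helper name is hypothetical; ori_vocab_size follows the field used in this diff):

def validate_max_logprobs(max_logprobs: int, ori_vocab_size: int) -> None:
    # -1 is a sentinel meaning "logprobs over the full vocabulary";
    # any other value must lie within [0, vocab_size].
    if max_logprobs < -1 or max_logprobs > ori_vocab_size:
        raise ValueError("max_logprobs must be -1 or in [0, vocab_size].")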

fastdeploy/engine/request.py

Lines changed: 12 additions & 1 deletion

@@ -29,7 +29,12 @@
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
+from fastdeploy.worker.output import (
+    LogprobsLists,
+    LogprobsTensors,
+    PromptLogprobs,
+    SampleLogprobs,
+)


 class RequestStatus(Enum):

@@ -462,6 +467,8 @@ def __init__(
         request_id: str,
         prompt: Optional[str] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
         output_type: Optional[int] = 3,
         outputs: CompletionOutput = None,
         finished: bool = False,

@@ -475,6 +482,8 @@ def __init__(
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_logprobs = prompt_logprobs
+        self.prompt_logprobs_tensors = prompt_logprobs_tensors
         self.output_type = output_type
         self.outputs = outputs
         self.finished = finished

@@ -520,6 +529,7 @@ def __repr__(self) -> str:
             f"RequestOutput(request_id={self.request_id}, "
             f"prompt={self.prompt!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
             f"output_type={self.output_type}, "
             f"outputs={self.outputs}, "
             f"finished={self.finished}, "

@@ -545,6 +555,7 @@ def to_dict(self):
             "request_id": self.request_id,
             "prompt": self.prompt,
             "prompt_token_ids": self.prompt_token_ids,
+            "prompt_logprobs": self.prompt_logprobs,
             "output_type": self.output_type,
             "outputs": None if self.outputs is None else self.outputs.to_dict(),
             "metrics": None if self.metrics is None else self.metrics.to_dict(),

fastdeploy/engine/sampling_params.py

Lines changed: 6 additions & 3 deletions

@@ -16,6 +16,7 @@

 from __future__ import annotations

+import os
 import random
 from dataclasses import dataclass, fields
 from enum import Enum

@@ -204,10 +205,12 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
             )
-        if self.logprobs is not None and self.logprobs < 0:
-            raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
-        if self.logprobs is not None and self.logprobs > 20:
+        if self.logprobs is not None and self.logprobs < -1:
+            raise ValueError(f"logprobs must be greater than or equal to -1, got {self.logprobs}.")
+        if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
             raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
+        if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
+            raise ValueError(f"prompt_logprobs can't be less than -1, got {self.prompt_logprobs}.")

         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

fastdeploy/entrypoints/llm.py

Lines changed: 133 additions & 3 deletions

@@ -16,11 +16,13 @@

 from __future__ import annotations

+import itertools
 import logging
 import threading
 import time
 import traceback
 import uuid
+from collections.abc import Iterable
 from typing import Any, Optional, Union

 from pydantic import ValidationError

@@ -37,13 +39,20 @@
     llm_logger,
     retrive_model_from_server,
 )
-from fastdeploy.worker.output import Logprob, LogprobsLists
+from fastdeploy.worker.output import (
+    Logprob,
+    LogprobsLists,
+    LogprobsTensors,
+    PromptLogprobs,
+)

 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
     if isinstance(handler, logging.StreamHandler):
         root_logger.removeHandler(handler)

+NONES = itertools.repeat(None)
+

 class LLM:
     """

@@ -189,12 +198,17 @@ def generate(
         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

         topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+        num_prompt_logprobs = (
+            sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs
+        )

         # get output
         if stream:
             return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         else:
-            outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
+            outputs = self._run_engine(
+                req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs
+            )
             for i in range(len(outputs)):
                 outputs[i].prompt = prompts[i]
             return outputs

@@ -295,6 +309,28 @@ def _add_request(
         if prompts is None:
             raise ValueError("prompts and prompt_ids cannot be both None.")

+        if kwargs.get("stream") and sampling_params.prompt_logprobs is not None:
+            raise ValueError("prompt_logprobs is not supported with streaming.")
+        max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
+        if max_logprobs == -1:
+            max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+        if sampling_params.logprobs is not None:
+            num_logprobs = sampling_params.logprobs
+            if num_logprobs == -1:
+                num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+            if num_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                )
+        if sampling_params.prompt_logprobs is not None:
+            num_prompt_logprobs = sampling_params.prompt_logprobs
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+            if num_prompt_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                )
+
         prompts_len = len(prompts)
         req_ids = []
         for i in range(prompts_len):
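
The resolution rule used twice in the hunk above (a requested value of -1 expands to the model's vocabulary size and must then stay within the max_logprobs ceiling) can be summarized in one hypothetical helper; this is an editor's sketch, not part of the commit:

def resolve_num_logprobs(requested: int, ori_vocab_size: int, max_logprobs: int) -> int:
    # -1 means "full vocabulary"; once resolved, the request must not exceed
    # the configured ceiling (which is resolved the same way in _add_request).
    resolved = ori_vocab_size if requested == -1 else requested
    if resolved > max_logprobs:
        raise ValueError(
            f"Number of logprobs requested ({resolved}) exceeds maximum allowed value ({max_logprobs})."
        )
    return resolved
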
@@ -377,7 +413,93 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int)
         except Exception as e:
             llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}")

-    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
+    def _build_prompt_logprobs(
+        self,
+        prompt_logprobs_tensors: LogprobsTensors,
+        num_prompt_logprobs: int,
+    ):
+        """Build prompt logprobs from the tensors returned by the worker.
+        Args:
+            prompt_logprobs_tensors: tuple containing the prompt logprobs
+                tensors.
+        """
+
+        token_ids, logprobs, ranks = prompt_logprobs_tensors
+
+        # Detokenize non-incrementally.
+        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
+        decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()]
+
+        # Recover shapes.
+        num_prompt_tokens, num_logprobs = logprobs.shape
+
+        # Pythonize the tensors.
+        prompt_token_ranks = ranks.tolist()
+        prompt_logprobs = logprobs.tolist()
+        token_ids = token_ids.tolist()
+        result: Optional[PromptLogprobs] = []
+        # Make a Logprob dictionary for each position.
+        for pos in range(num_prompt_tokens):
+            # Handle flattening.
+            offset = pos * num_logprobs
+            offset_end = offset + num_logprobs
+            decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
+
+            # Update with the Logprob dictionary for this pos.
+            result.append(
+                self._make_logprob_dict(
+                    prompt_logprobs[pos],
+                    token_ids[pos],
+                    decoded_tokens_for_pos,
+                    prompt_token_ranks[pos],
+                    num_prompt_logprobs,
+                )
+            )
+        return result
+
+    @staticmethod
+    def _make_logprob_dict(
+        logprobs: list[float],
+        logprob_token_ids: list[int],
+        decoded_tokens: Iterable[str | None],
+        rank: int,
+        num_logprobs: int,
+    ) -> dict[int, Logprob]:
+        """Make a Logprob dictionary for a position.
+        Args:
+            logprobs: list of log probabilities
+            logprob_token_ids: list of top token ids
+            decoded_tokens: list of decoded top tokens
+            rank: rank of the sampled token
+            num_logprobs: number of logprobs requested
+                by the user (in addition to sampled logprob)
+        Returns:
+            dict[token id, Logprob]
+        """
+        if num_logprobs == -1:
+            num_logprobs = len(logprobs)
+        # We do not need a special case for the sampled token
+        # being in the topk, since inserting duplicated data
+        # into a dictionary twice is the same as doing it once.
+        topk_ranks = range(1, num_logprobs + 1)
+        ranks = itertools.chain((rank,), topk_ranks)
+
+        return {
+            token_id: Logprob(
+                logprob=logprob,
+                rank=rank,
+                decoded_token=token,
+            )
+            for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
+        }
+
+    def _run_engine(
+        self,
+        req_ids: list[str],
+        use_tqdm: bool,
+        topk_logprobs: Optional[int] = None,
+        num_prompt_logprobs: Optional[int] = None,
+    ):
         """
         Run the engine and return the list of results.


@@ -422,9 +544,17 @@ def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):

                 # filter logprobs
                 if result.outputs.top_logprobs and topk_logprobs:
+                    if topk_logprobs == -1:
+                        topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                     result.outputs.logprobs = self._build_sample_logprobs(
                         result.outputs.top_logprobs, topk_logprobs
                     )
+                if result.prompt_logprobs_tensors and num_prompt_logprobs:
+                    if num_prompt_logprobs == -1:
+                        num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                    result.prompt_logprobs = self._build_prompt_logprobs(
+                        result.prompt_logprobs_tensors, num_prompt_logprobs
+                    )

                 output[pos] = result
                 finished.append(i)
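
Taken together, the offline path can be exercised roughly as below. This is a hedged end-to-end sketch: the model path is a placeholder, and argument and attribute names are taken from this diff where visible, assumed otherwise.

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="./your-model-dir")  # placeholder model path
params = SamplingParams(max_tokens=32, logprobs=5, prompt_logprobs=5)
outputs = llm.generate(["Hello, how are you?"], sampling_params=params)

for out in outputs:
    # One {token_id: Logprob} dict per prompt position, built by _build_prompt_logprobs.
    print(out.prompt_logprobs)
    # Per-generated-token logprobs, built by _build_sample_logprobs.
    print(out.outputs.logprobs)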

fastdeploy/output/token_processor.py

Lines changed: 13 additions & 0 deletions

@@ -285,6 +285,19 @@ def _process_batch_output_use_zmq(self, receive_datas):
                 finished=False,
                 metrics=metrics,
             )
+            if self.use_logprobs:
+                if getattr(stream_data, "logprobs", None) is not None:
+                    try:
+                        logprobs_list: LogprobsLists = stream_data.logprobs.tolists()
+                        result.outputs.logprob = float(logprobs_list.logprobs[0][0])
+                        result.outputs.top_logprobs = logprobs_list
+                    except Exception as e:
+                        llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
+                if getattr(stream_data, "prompt_logprobs", None) is not None:
+                    try:
+                        result.prompt_logprobs_tensors = stream_data.prompt_logprobs
+                    except Exception as e:
+                        llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
             if self.tokens_counter[task_id] == 0:
                 if task.messages is not None:
                     result.prompt = task.messages

fastdeploy/worker/output.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ class Logprob(NamedTuple):
     decoded_token: Optional[str] = None


+PromptLogprobs = list[dict[int, Logprob] | None]
 # [{token_id, logprob}] for tokens sampled from the top-k
 SampleLogprobs = list[dict[int, Logprob]]
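
For orientation, _build_prompt_logprobs in llm.py unpacks a (token_ids, logprobs, ranks) triple whose first two members have shape [num_prompt_tokens, num_logprobs], and turns it into a PromptLogprobs list with one dict per position. A shape sketch with illustrative values (plain Python lists stand in for the actual tensors):

# Hypothetical worker output for a 3-token prompt with top-2 logprobs per position.
token_ids = [[11, 7], [4, 9], [25, 3]]                  # [num_prompt_tokens, num_logprobs]
logprobs = [[-0.1, -2.3], [-0.4, -1.9], [-0.7, -1.2]]   # same shape as token_ids
ranks = [1, 2, 1]                                       # rank of each actual prompt token

# _build_prompt_logprobs would emit a PromptLogprobs list of length 3,
# each entry mapping token_id -> Logprob(logprob, rank, decoded_token).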
