
Commit 2c2b9c2

add prompt logprobs
1 parent fa09838 commit 2c2b9c2


9 files changed: +758 −9 lines changed


fastdeploy/config.py

Lines changed: 2 additions & 2 deletions
@@ -227,8 +227,8 @@ def __init__(
         self.think_end_id = args.get("think_end_id", -1)
         self.im_patch_id = args.get("image_patch_id", -1)
         self.line_break_id = args.get("line_break_id", -1)
-        if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
-            self.max_logprobs = self.vocab_size
+        if self.max_logprobs < -1 or self.max_logprobs > self.ori_vocab_size:
+            raise ValueError("The possible values for max_logprobs are -1 and [0, vocab_size].")
 
         self._post_init()
 
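For reference, the new check treats `-1` as a sentinel for "logprobs over the full vocabulary" and rejects anything outside `[0, vocab_size]`. A minimal standalone sketch of that rule; the helper name and the sample vocabulary size are illustrative, not part of the commit:

```python
def validate_max_logprobs(max_logprobs: int, vocab_size: int) -> None:
    # -1 means "use the full vocabulary"; any other value must fall in [0, vocab_size].
    if max_logprobs < -1 or max_logprobs > vocab_size:
        raise ValueError("The possible values for max_logprobs are -1 and [0, vocab_size].")


validate_max_logprobs(-1, 103424)    # accepted: full-vocabulary logprobs
validate_max_logprobs(20, 103424)    # accepted: bounded top-k
# validate_max_logprobs(-5, 103424)  # rejected: below the -1 sentinel
```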

fastdeploy/engine/request.py

Lines changed: 12 additions & 1 deletion
@@ -29,7 +29,12 @@
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
+from fastdeploy.worker.output import (
+    LogprobsLists,
+    LogprobsTensors,
+    PromptLogprobs,
+    SampleLogprobs,
+)
 
 
 class RequestStatus(Enum):
@@ -462,6 +467,8 @@ def __init__(
         request_id: str,
         prompt: Optional[str] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
         output_type: Optional[int] = 3,
         outputs: CompletionOutput = None,
         finished: bool = False,
@@ -475,6 +482,8 @@ def __init__(
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_logprobs = prompt_logprobs
+        self.prompt_logprobs_tensors = prompt_logprobs_tensors
         self.output_type = output_type
         self.outputs = outputs
         self.finished = finished
@@ -520,6 +529,7 @@ def __repr__(self) -> str:
             f"RequestOutput(request_id={self.request_id}, "
             f"prompt={self.prompt!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
             f"output_type={self.output_type}, "
             f"outputs={self.outputs}, "
             f"finished={self.finished}, "
@@ -545,6 +555,7 @@ def to_dict(self):
             "request_id": self.request_id,
             "prompt": self.prompt,
             "prompt_token_ids": self.prompt_token_ids,
+            "prompt_logprobs": self.prompt_logprobs,
             "output_type": self.output_type,
             "outputs": None if self.outputs is None else self.outputs.to_dict(),
             "metrics": None if self.metrics is None else self.metrics.to_dict(),

fastdeploy/engine/sampling_params.py

Lines changed: 6 additions & 3 deletions
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import os
 import random
 from dataclasses import dataclass, fields
 from enum import Enum
@@ -204,10 +205,12 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
             )
-        if self.logprobs is not None and self.logprobs < 0:
-            raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
-        if self.logprobs is not None and self.logprobs > 20:
+        if self.logprobs is not None and self.logprobs < -1:
+            raise ValueError(f"logprobs must be greater than or equal to -1, got {self.logprobs}.")
+        if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
             raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
+        if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
+            raise ValueError(f"prompt_logprobs can't be less than -1, got {self.prompt_logprobs}.")
 
         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

fastdeploy/entrypoints/llm.py

Lines changed: 132 additions & 3 deletions
@@ -16,11 +16,13 @@
 
 from __future__ import annotations
 
+import itertools
 import logging
 import threading
 import time
 import traceback
 import uuid
+from collections.abc import Iterable
 from typing import Any, Optional, Union
 
 from pydantic import ValidationError
@@ -37,13 +39,20 @@
     llm_logger,
     retrive_model_from_server,
 )
-from fastdeploy.worker.output import Logprob, LogprobsLists
+from fastdeploy.worker.output import (
+    Logprob,
+    LogprobsLists,
+    LogprobsTensors,
+    PromptLogprobs,
+)
 
 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
     if isinstance(handler, logging.StreamHandler):
         root_logger.removeHandler(handler)
 
+NONES = itertools.repeat(None)
+
 
 class LLM:
     """
@@ -189,12 +198,17 @@ def generate(
         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
 
         topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+        num_prompt_logprobs = (
+            sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs
+        )
 
         # get output
         if stream:
             return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         else:
-            outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
+            outputs = self._run_engine(
+                req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs
+            )
             for i in range(len(outputs)):
                 outputs[i].prompt = prompts[i]
             return outputs
@@ -321,6 +335,27 @@ def _add_request(
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
+            if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
+                raise ValueError("prompt_logprobs is not supported with streaming.")
+            max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
+            if max_logprobs == -1:
+                max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+            if current_sampling_params.logprobs is not None:
+                num_logprobs = current_sampling_params.logprobs
+                if num_logprobs == -1:
+                    num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                if num_logprobs > max_logprobs:
+                    raise ValueError(
+                        f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    )
+            if current_sampling_params.prompt_logprobs is not None:
+                num_prompt_logprobs = current_sampling_params.prompt_logprobs
+                if num_prompt_logprobs == -1:
+                    num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                if num_prompt_logprobs > max_logprobs:
+                    raise ValueError(
+                        f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    )
             if current_sampling_params.guided_decoding is not None:
                 guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
                 tasks.update(guided_decoding_dict)
@@ -377,7 +412,93 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i
         except Exception as e:
             llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}")
 
-    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
+    def _build_prompt_logprobs(
+        self,
+        prompt_logprobs_tensors: LogprobsTensors,
+        num_prompt_logprobs: int,
+    ):
+        """Update with prompt logprobs from worker.
+        Args:
+            prompt_logprobs_tensors: tuple containing the prompt logprobs
+                tensors.
+        """
+
+        token_ids, logprobs, ranks = prompt_logprobs_tensors
+
+        # Detokenize non-incrementally.
+        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
+        decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()]
+
+        # Recover shapes.
+        num_prompt_tokens, num_logprobs = logprobs.shape
+
+        # Pythonize the tensors.
+        prompt_token_ranks = ranks.tolist()
+        prompt_logprobs = logprobs.tolist()
+        token_ids = token_ids.tolist()
+        result: Optional[PromptLogprobs] = []
+        # Make Logprob for each position.
+        for pos in range(num_prompt_tokens):
+            # Handle flattening.
+            offset = pos * num_logprobs
+            offset_end = offset + num_logprobs
+            decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
+
+            # Update with the Logprob dictionary for this pos.
+            result.append(
+                self._make_logprob_dict(
+                    prompt_logprobs[pos],
+                    token_ids[pos],
+                    decoded_tokens_for_pos,
+                    prompt_token_ranks[pos],
+                    num_prompt_logprobs,
+                )
+            )
+        return result
+
+    @staticmethod
+    def _make_logprob_dict(
+        logprobs: list[float],
+        logprob_token_ids: list[int],
+        decoded_tokens: Iterable[str | None],
+        rank: int,
+        num_logprobs: int,
+    ) -> dict[int, Logprob]:
+        """Make a Logprob dictionary for a position.
+        Args:
+            logprobs: list of log probabilities
+            logprob_token_ids: list of top token ids
+            decoded_tokens: list of decoded top tokens
+            rank: rank of the sampled token
+            num_logprobs: number of logprobs requested
+                by the user (in addition to sampled logprob)
+        Returns:
+            dict[token id, Logprob]
+        """
+        if num_logprobs == -1:
+            num_logprobs = len(logprobs)
+        # We do not need a special case for the sampled token
+        # being in the topk, since inserting duplicated data
+        # into a dictionary twice is the same as doing it once.
+        topk_ranks = range(1, num_logprobs + 1)
+        ranks = itertools.chain((rank,), topk_ranks)
+
+        return {
+            token_id: Logprob(
+                logprob=logprob,
+                rank=rank,
+                decoded_token=token,
+            )
+            for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
+        }
+
+    def _run_engine(
+        self,
+        req_ids: list[str],
+        use_tqdm: bool,
+        topk_logprobs: Optional[int] = None,
+        num_prompt_logprobs: Optional[int] = None,
+    ):
         """
         Run the engine and return the list of results.
 
@@ -422,9 +543,17 @@ def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optiona
 
                     # filter logprobs
                     if result.outputs.top_logprobs and topk_logprobs:
+                        if topk_logprobs == -1:
+                            topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                         result.outputs.logprobs = self._build_sample_logprobs(
                             result.outputs.top_logprobs, topk_logprobs
                         )
+                    if result.prompt_logprobs_tensors and num_prompt_logprobs:
+                        if num_prompt_logprobs == -1:
+                            num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                        result.prompt_logprobs = self._build_prompt_logprobs(
+                            result.prompt_logprobs_tensors, num_prompt_logprobs
+                        )
 
                     output[pos] = result
                     finished.append(i)
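Putting the entrypoint changes together, a hedged end-to-end sketch of how the new parameter is meant to be used from the offline `LLM` API. The model path is a placeholder and constructor arguments outside this diff are assumptions; note that `prompt_logprobs` is rejected when streaming:

```python
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import SamplingParams

llm = LLM(model="/path/to/model")  # placeholder path; other LLM kwargs omitted
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params=SamplingParams(max_tokens=16, logprobs=5, prompt_logprobs=5),
)

result = outputs[0]
# result.prompt_logprobs: one dict[token_id, Logprob] per prompt position,
# built by _build_prompt_logprobs from the worker's tensors.
for pos, entry in enumerate(result.prompt_logprobs or []):
    if entry:
        token_id, top = max(entry.items(), key=lambda kv: kv[1].logprob)
        print(pos, token_id, top.decoded_token, top.logprob)
```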

fastdeploy/output/token_processor.py

Lines changed: 13 additions & 0 deletions
@@ -285,6 +285,19 @@ def _process_batch_output_use_zmq(self, receive_datas):
                     finished=False,
                     metrics=metrics,
                 )
+                if self.use_logprobs:
+                    if getattr(stream_data, "logprobs", None) is not None:
+                        try:
+                            logprobs_list: LogprobsLists = stream_data.logprobs.tolists()
+                            result.outputs.logprob = float(logprobs_list.logprobs[0][0])
+                            result.outputs.top_logprobs = logprobs_list
+                        except Exception as e:
+                            llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
+                    if getattr(stream_data, "prompt_logprobs", None) is not None:
+                        try:
+                            result.prompt_logprobs_tensors = stream_data.prompt_logprobs
+                        except Exception as e:
+                            llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
                 if self.tokens_counter[task_id] == 0:
                     if task.messages is not None:
                         result.prompt = task.messages

fastdeploy/worker/output.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class Logprob(NamedTuple):
     decoded_token: Optional[str] = None
 
 
+PromptLogprobs = list[dict[int, Logprob] | None]
 # [{token_id, logprob}] for tokens sampled from the top-k
 SampleLogprobs = list[dict[int, Logprob]]
 