16 | 16 |
17 | 17 | from __future__ import annotations |
18 | 18 |
| 19 | +import itertools |
19 | 20 | import logging |
20 | 21 | import threading |
21 | 22 | import time |
22 | 23 | import traceback |
23 | 24 | import uuid |
| 25 | +from collections.abc import Iterable |
24 | 26 | from typing import Any, Optional, Union |
25 | 27 |
26 | 28 | from pydantic import ValidationError |
37 | 39 | llm_logger, |
38 | 40 | retrive_model_from_server, |
39 | 41 | ) |
40 | | -from fastdeploy.worker.output import Logprob, LogprobsLists |
| 42 | +from fastdeploy.worker.output import ( |
| 43 | + Logprob, |
| 44 | + LogprobsLists, |
| 45 | + LogprobsTensors, |
| 46 | + PromptLogprobs, |
| 47 | +) |
41 | 48 |
42 | 49 | root_logger = logging.getLogger() |
43 | 50 | for handler in root_logger.handlers[:]: |
44 | 51 | if isinstance(handler, logging.StreamHandler): |
45 | 52 | root_logger.removeHandler(handler) |
46 | 53 |
| 54 | +NONES = itertools.repeat(None) |
| 55 | + |
47 | 56 |
48 | 57 | class LLM: |
49 | 58 | """ |
@@ -189,12 +198,17 @@ def generate( |
189 | 198 | req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params) |
190 | 199 |
191 | 200 | topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs |
| 201 | + num_prompt_logprobs = ( |
| 202 | + sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs |
| 203 | + ) |
192 | 204 |
193 | 205 | # get output |
194 | 206 | if stream: |
195 | 207 | return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs) |
196 | 208 | else: |
197 | | - outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs) |
| 209 | + outputs = self._run_engine( |
| 210 | + req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs |
| 211 | + ) |
198 | 212 | for i in range(len(outputs)): |
199 | 213 | outputs[i].prompt = prompts[i] |
200 | 214 | return outputs |
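
For context, a minimal usage sketch of the new `prompt_logprobs` knob as it flows through `generate()`. The import path, model path, and `max_tokens` are assumptions rather than facts from this diff; only `logprobs` and `prompt_logprobs` come from the code above.

```python
# Sketch only: import path, model path, and max_tokens are assumptions.
from fastdeploy import LLM, SamplingParams

llm = LLM(model="./your-model-path")
params = SamplingParams(
    max_tokens=16,
    logprobs=5,          # top-5 logprobs for each generated token
    prompt_logprobs=5,   # top-5 logprobs for each prompt token (added in this PR)
)
# prompt_logprobs only works in non-streaming mode; with streaming enabled,
# _add_request raises "prompt_logprobs is not supported with streaming."
outputs = llm.generate(prompts=["Hello, world"], sampling_params=params)
```
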
@@ -321,6 +335,27 @@ def _add_request( |
321 | 335 | current_sampling_params = sampling_params[i] |
322 | 336 | else: |
323 | 337 | current_sampling_params = sampling_params |
| 338 | + if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None: |
| 339 | + raise ValueError("prompt_logprobs is not supported with streaming.") |
| 340 | + max_logprobs = self.llm_engine.cfg.model_config.max_logprobs |
| 341 | + if max_logprobs == -1: |
| 342 | + max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size |
| 343 | + if current_sampling_params.logprobs is not None: |
| 344 | + num_logprobs = current_sampling_params.logprobs |
| 345 | + if num_logprobs == -1: |
| 346 | + num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size |
| 347 | + if num_logprobs > max_logprobs: |
| 348 | + raise ValueError( |
| 349 | + f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})." |
| 350 | + ) |
| 351 | + if current_sampling_params.prompt_logprobs is not None: |
| 352 | + num_prompt_logprobs = current_sampling_params.prompt_logprobs |
| 353 | + if num_prompt_logprobs == -1: |
| 354 | + num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size |
| 355 | + if num_prompt_logprobs > max_logprobs: |
| 356 | + raise ValueError( |
  | 357 | +                    f"Number of prompt logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})." |
| 358 | + ) |
324 | 359 | if current_sampling_params.guided_decoding is not None: |
325 | 360 | guided_decoding_dict = current_sampling_params.guided_decoding.to_dict() |
326 | 361 | tasks.update(guided_decoding_dict) |
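
The validation above resolves `-1` as "full vocabulary" for both the model-level `max_logprobs` and the per-request values, then rejects any request above the resolved cap. A standalone sketch of that resolution logic, using a hypothetical helper name (`_resolve_num_logprobs` does not exist in the codebase):

```python
from typing import Optional

def _resolve_num_logprobs(requested: Optional[int], max_logprobs: int, vocab_size: int) -> Optional[int]:
    """Hypothetical helper mirroring the checks added to _add_request."""
    if requested is None:
        return None
    if max_logprobs == -1:          # model allows the full vocabulary
        max_logprobs = vocab_size
    resolved = vocab_size if requested == -1 else requested
    if resolved > max_logprobs:
        raise ValueError(
            f"Number of logprobs requested ({resolved}) exceeds maximum allowed value ({max_logprobs})."
        )
    return resolved

# _resolve_num_logprobs(5, 20, 100352)   -> 5
# _resolve_num_logprobs(-1, 20, 100352)  -> ValueError (100352 > 20)
# _resolve_num_logprobs(-1, -1, 100352)  -> 100352 (full vocabulary)
```
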
@@ -377,7 +412,93 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i |
377 | 412 | except Exception as e: |
378 | 413 | llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}") |
379 | 414 |
380 | | - def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None): |
| 415 | + def _build_prompt_logprobs( |
| 416 | + self, |
| 417 | + prompt_logprobs_tensors: LogprobsTensors, |
| 418 | + num_prompt_logprobs: int, |
| 419 | + ): |
  | 420 | +        """Build per-position prompt logprobs from the worker's tensors. |
  | 421 | +        Args: |
  | 422 | +            prompt_logprobs_tensors: tuple of (token_ids, logprobs, ranks) tensors. |
  | 423 | +            num_prompt_logprobs: number of prompt logprobs requested per token. |
  | 424 | +        """ |
| 426 | + token_ids, logprobs, ranks = prompt_logprobs_tensors |
| 427 | + |
| 428 | + # Detokenize non-incrementally. |
| 429 | + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] |
| 430 | + decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()] |
| 431 | + |
| 432 | + # Recover shapes. |
| 433 | + num_prompt_tokens, num_logprobs = logprobs.shape |
| 434 | + |
| 435 | + # Pythonize the torch tensors. |
| 436 | + prompt_token_ranks = ranks.tolist() |
| 437 | + prompt_logprobs = logprobs.tolist() |
| 438 | + token_ids = token_ids.tolist() |
| 439 | + result: Optional[PromptLogprobs] = [] |
| 440 | + # Make Logprob for each position. |
| 441 | + for pos in range(num_prompt_tokens): |
| 442 | + # Handle flattening. |
| 443 | + offset = pos * num_logprobs |
| 444 | + offset_end = offset + num_logprobs |
| 445 | + decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] |
| 446 | + |
| 447 | + # Update with the Logprob dictionary for this pos. |
| 448 | + result.append( |
| 449 | + self._make_logprob_dict( |
| 450 | + prompt_logprobs[pos], |
| 451 | + token_ids[pos], |
| 452 | + decoded_tokens_for_pos, |
| 453 | + prompt_token_ranks[pos], |
| 454 | + num_prompt_logprobs, |
| 455 | + ) |
| 456 | + ) |
| 457 | + return result |
| 458 | + |
| 459 | + @staticmethod |
| 460 | + def _make_logprob_dict( |
| 461 | + logprobs: list[float], |
| 462 | + logprob_token_ids: list[int], |
| 463 | + decoded_tokens: Iterable[str | None], |
| 464 | + rank: int, |
| 465 | + num_logprobs: int, |
| 466 | + ) -> dict[int, Logprob]: |
| 467 | + """Make a Logprob dictionary for a position. |
| 468 | + Args: |
| 469 | + logprobs: list of log probabilities |
| 470 | + logprob_token_ids: list of top token ids |
| 471 | + decoded_tokens: list of decoded top tokens |
| 472 | + rank: rank of the sampled token |
| 473 | + num_logprobs: number of logprobs requested |
| 474 | + by the user (in addition to sampled logprob) |
| 475 | + Returns: |
| 476 | + dict[token id, Logprob] |
| 477 | + """ |
| 478 | + if num_logprobs == -1: |
| 479 | + num_logprobs = len(logprobs) |
| 480 | + # We do not need a special case for the sampled token |
| 481 | + # being in the topk, since inserting duplicated data |
| 482 | + # into a dictionary twice is the same as doing it once. |
| 483 | + topk_ranks = range(1, num_logprobs + 1) |
| 484 | + ranks = itertools.chain((rank,), topk_ranks) |
| 485 | + |
| 486 | + return { |
| 487 | + token_id: Logprob( |
| 488 | + logprob=logprob, |
| 489 | + rank=rank, |
| 490 | + decoded_token=token, |
| 491 | + ) |
| 492 | + for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) |
| 493 | + } |
| 494 | + |
| 495 | + def _run_engine( |
| 496 | + self, |
| 497 | + req_ids: list[str], |
| 498 | + use_tqdm: bool, |
| 499 | + topk_logprobs: Optional[int] = None, |
| 500 | + num_prompt_logprobs: Optional[int] = None, |
| 501 | + ): |
381 | 502 | """ |
382 | 503 |         Run the engine and return the list of results. |
383 | 504 |
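
To illustrate what `_make_logprob_dict` produces for a single prompt position: slot 0 of the token-id/logprob lists carries the actual token, slots 1..k carry the top-k candidates, and the rank iterator chains the actual token's rank in front of ranks 1..k. Plain tuples stand in for the `Logprob` dataclass to keep the snippet self-contained; the values are made up.

```python
import itertools

# One position: the actual prompt token (id 42) is also the 2nd-best candidate.
logprob_token_ids = [42, 7, 42]      # [actual, top-1, top-2]
logprobs = [-1.2, -0.3, -1.2]
rank_of_actual = 2
ranks = itertools.chain((rank_of_actual,), range(1, 2 + 1))   # yields 2, 1, 2

merged = {tid: (lp, r) for tid, lp, r in zip(logprob_token_ids, logprobs, ranks)}
print(merged)  # {42: (-1.2, 2), 7: (-0.3, 1)} -- the duplicate id 42 collapses into one entry
```
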
@@ -422,9 +543,17 @@ def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optiona |
422 | 543 |
423 | 544 | # filter logprobs |
424 | 545 | if result.outputs.top_logprobs and topk_logprobs: |
| 546 | + if topk_logprobs == -1: |
| 547 | + topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size |
425 | 548 | result.outputs.logprobs = self._build_sample_logprobs( |
426 | 549 | result.outputs.top_logprobs, topk_logprobs |
427 | 550 | ) |
| 551 | + if result.prompt_logprobs_tensors and num_prompt_logprobs: |
| 552 | + if num_prompt_logprobs == -1: |
| 553 | + num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size |
| 554 | + result.prompt_logprobs = self._build_prompt_logprobs( |
| 555 | + result.prompt_logprobs_tensors, num_prompt_logprobs |
| 556 | + ) |
428 | 557 |
429 | 558 | output[pos] = result |
430 | 559 | finished.append(i) |
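
Continuing the usage sketch above, a reader-side sketch of the new field: the attribute names (`prompt_logprobs`, `Logprob.logprob`, `Logprob.decoded_token`) follow this diff, while the exact layout of the output objects is an assumption.

```python
# lp_dict maps token_id -> Logprob(logprob, rank, decoded_token), one dict per prompt token.
for out in outputs:
    for pos, lp_dict in enumerate(getattr(out, "prompt_logprobs", None) or []):
        best_id = max(lp_dict, key=lambda tid: lp_dict[tid].logprob)
        print(pos, best_id, lp_dict[best_id].decoded_token, lp_dict[best_id].logprob)
```
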