Skip to content

Commit

Permalink
[Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS (#12368)
Browse files Browse the repository at this point in the history
Signed-off-by: Aviv Keshet <[email protected]>
  • Loading branch information
akeshet authored Feb 5, 2025
1 parent 75e9430 commit b3a0d01
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
9 changes: 9 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
Expand Down Expand Up @@ -282,6 +283,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_LOGGING_PREFIX":
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),

# if set, vllm will call logits processors in a thread pool with this many
# threads. This is useful when using custom logits processors that either
# (a) launch additional CUDA kernels or (b) do significant CPU-bound work
# while not holding the python GIL, or both.
"VLLM_LOGITS_PROCESSOR_THREADS":
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,

# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
Expand Down
46 changes: 35 additions & 11 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import torch
Expand All @@ -15,6 +16,11 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform

_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)


class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.
Expand Down Expand Up @@ -135,6 +141,7 @@ def _apply_logits_processors(
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
Expand All @@ -148,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)

logits[logits_row_idx] = logits_row
if _logits_processor_threadpool is not None:
logits_row_ids_and_logits_row_futures.append(
(logits_row_idx,
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)

logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)

for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()

if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits


def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row

0 comments on commit b3a0d01

Please sign in to comment.