36 commits
96dc2f0
feat: limit thinking tokens
llsj14 Jul 12, 2025
5273b03
remove comment
llsj14 Jul 12, 2025
64c5848
update states only in update_state method
llsj14 Jul 14, 2025
43b494f
make precommit and lint
llsj14 Jul 14, 2025
8849eb5
revert change of deepseek reasoning parser
llsj14 Jul 15, 2025
13d75c9
support think start/end as token sequences
llsj14 Jul 16, 2025
9440408
refactor and change logic faster
llsj14 Jul 17, 2025
e9a8198
rename parameter and logit processor
llsj14 Jul 18, 2025
ceb241f
add reasoning effort param
llsj14 Jul 18, 2025
cf2f127
remove constraint of the reasoning model
llsj14 Jul 18, 2025
e21d935
update logit processor
llsj14 Jul 19, 2025
f4470c0
pass ruff
llsj14 Jul 19, 2025
83510b4
pass precommit
llsj14 Jul 19, 2025
e552fb6
fix format
llsj14 Jul 19, 2025
373e10f
fix: loads none error
llsj14 Jul 21, 2025
a832beb
fix return type
llsj14 Jul 21, 2025
0f58220
fix error
llsj14 Jul 21, 2025
d09361f
update ReasoningConfig handling
llsj14 Jul 21, 2025
f30e218
fix config and EngineArgs
llsj14 Jul 21, 2025
ee3a650
simplify reasoning config checks and fix errors
llsj14 Jul 22, 2025
459bb17
fix type of think token ids
llsj14 Jul 22, 2025
49a9959
refactor ThinkingTokenBudgetLogitsProcessor
llsj14 Jul 27, 2025
a098139
fix import error from rebase
llsj14 Jul 27, 2025
63f9667
fix: remove duplicate reasoning_effort field in ChatCompletionRequest
llsj14 Aug 16, 2025
0fd30b9
fix runtime error after rebase
llsj14 Aug 17, 2025
fff1161
check reasoning is enabled
llsj14 Aug 18, 2025
428622b
add test and implement processor with incremental token processing op…
llsj14 Aug 19, 2025
857540b
remove connection between reasoning_effort and thinking_token_budget
llsj14 Aug 20, 2025
dbd0a79
fix: support corner cases
llsj14 Aug 23, 2025
c1b91b0
cleanup unused parameters
llsj14 Aug 23, 2025
1769198
optimize speed up performance while apply logit processor
llsj14 Aug 23, 2025
737ac04
utilize logits processor when it is needed, not every step for speed up
llsj14 Sep 4, 2025
e069345
refactor processor
llsj14 Sep 5, 2025
b596d23
add comment on state
llsj14 Sep 17, 2025
0e222cb
fix tokenizer init bug
llsj14 Sep 17, 2025
90975b8
make precommit
llsj14 Sep 17, 2025
181 changes: 165 additions & 16 deletions tests/v1/logits_processors/test_correctness.py
@@ -20,13 +20,10 @@
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
build_logitsprocs)
from vllm.v1.sample.logits_processor import (
BatchUpdate, BatchUpdateBuilder, LogitBiasLogitsProcessor, LogitsProcessor,
MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality,
ThinkingTokenBudgetLogitsProcessor, build_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata

@@ -43,6 +40,11 @@
REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none"

# ThinkingTokenBudgetLogitsProcessor testing constants
THINKING_TOKEN_BUDGET = 5
THINK_START_TOKEN_ID = 999
THINK_END_TOKEN_ID = 998

# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]

@@ -62,10 +64,24 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType):
self.workload_index = workload_index
self.logitproc_type = logitproc_type
# Number of output tokens is randomly 0 or twice the min-tokens
# threshold which will be used in testing. Output token values
# don't matter *for these tests* so use 0 as a dummy value
self.out_tokens = ([0] *
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
# threshold which will be used in testing.
# Generate diverse random tokens for all processors (more realistic)
num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)
if num_tokens > 0:
# Use diverse random tokens
self.out_tokens = [
random.randint(1, 950) for _ in range(num_tokens)
]
# Set first token for ThinkingTokenBudget testing
is_thinking_processor = (logitproc_type
is ThinkingTokenBudgetLogitsProcessor or
(hasattr(logitproc_type, '__name__')
and logitproc_type.__name__
== 'ThinkingTokenBudgetLogitsProcessor'))
if is_thinking_processor:
self.out_tokens[0] = THINK_START_TOKEN_ID
else:
self.out_tokens = []
self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type)

@@ -75,6 +91,15 @@ def __str__(self):
return f"MyClass({summ})"


class MockReasoningConfig:
"""Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor."""
think_start_token_ids = [THINK_START_TOKEN_ID]
think_end_token_ids = [THINK_END_TOKEN_ID]

def is_thinking_enabled(self) -> bool:
return True


def _generate_fake_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -92,8 +117,12 @@ def _generate_fake_sampling_metadata(
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())

vllm_config = VllmConfig()
vllm_config.reasoning_config = MockReasoningConfig()

logitsprocs = build_logitsprocs(
vllm_config=VllmConfig(),
vllm_config=vllm_config,
device=device,
is_pin_memory=PIN_MEMORY_AVAILABLE,
is_pooling_model=False,
@@ -368,6 +397,115 @@ def _min_tokens_validate(
step_idx=step_idx)


def _thinking_budget_params(kwargs: dict) -> None:
"""Set SamplingParams kwargs for thinking token budget tests"""
kwargs["thinking_token_budget"] = THINKING_TOKEN_BUDGET


def _thinking_budget_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
logits_new: torch.Tensor,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
) -> None:
"""Validate thinking token budget processor behavior"""
# Get the ThinkingTokenBudgetLogitsProcessor instance
tb_processor: ThinkingTokenBudgetLogitsProcessor = next(
test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor))

# Get current request state
state = tb_processor._state.get(batch_index)
params = request_params.params

# Validate thinking token budget configuration
if hasattr(params,
'thinking_token_budget') and params.thinking_token_budget:
# State should exist for requests with thinking_token_budget
if state is None:
_raise_error_invalid(msg_suffix=(
f"Expected state for batch {batch_index} "
f"with thinking_token_budget={params.thinking_token_budget}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)

# Validate budget matches what was set
expected_budget = params.thinking_token_budget
actual_budget = state["thinking_token_budget"]

if actual_budget != expected_budget:
_raise_error_invalid(
msg_suffix=(f"Budget mismatch: expected {expected_budget}, "
f"got {actual_budget}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)

# Check if we're in thinking mode and validate token counting
output_tokens = request_params.out_tokens

# Find if thinking has started in output tokens
thinking_started = False
start_tokens = tb_processor.think_start_token_ids

if len(start_tokens) > 0:
for i in range(len(output_tokens) - len(start_tokens) + 1):
if output_tokens[i:i + len(start_tokens)] == start_tokens:
thinking_started = True
break

if thinking_started:
# If budget is exceeded, validate end token forcing
think_count = state["think_count"]
budget = state["thinking_token_budget"]

if think_count >= budget:
if not state["in_end"]:
_raise_error_invalid(
msg_suffix=(f"Budget exceeded ({think_count} >= "
f"{budget}) but not "
"forcing end tokens"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)

# Validate that only end tokens are allowed
end_tokens = tb_processor.think_end_token_ids
if len(end_tokens) > 0:
expected_end_token_id = end_tokens[min(
state["end_count"],
len(end_tokens) - 1)]

# Check logits masking
batch_logits = logits_new[batch_index]
for token_id in range(len(batch_logits)):
logit_value = batch_logits[token_id]

if token_id == expected_end_token_id:
# End token should not be masked
if logit_value == -float("inf"):
_raise_error_invalid(
msg_suffix=(
f"End token {token_id} should not be "
"masked but is"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
else:
# All other tokens should be masked when forcing end
if logit_value != -float("inf"):
_raise_error_invalid(
msg_suffix=(
f"Token {token_id} should be masked "
f"when forcing end tokens, but "
f"logit={logit_value}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)


def _none_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
@@ -413,16 +551,27 @@ class LogitsprocTestHelpers(NamedTuple):
MinTokensLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
eval_fxn=_min_tokens_validate),
ThinkingTokenBudgetLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_thinking_budget_params,
eval_fxn=_thinking_budget_validate),
}


def _get_test_cases() -> list[list[str]]:
"""Each test case is a set of logitsprocs"""
logitsprocs_types = list(logitsprocs_test_mapping.keys())
return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
for logitproc_type in logitsprocs_types
if logitproc_type != STR_NO_LOGITPROC
] + [logitsprocs_types]

# Isolate ThinkingTokenBudgetLogitsProcessor from all other processors
# to avoid unexpected modification of logits interference
thinking_processor = ThinkingTokenBudgetLogitsProcessor
other_processors = [
p for p in logitsprocs_types
if p != STR_NO_LOGITPROC and p != thinking_processor
]

return ([[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
for logitproc_type in other_processors] +
[other_processors] + [[thinking_processor]])


def _generate_fake_step_update(
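The validator above decides whether a request has entered thinking mode by scanning the generated tokens for the configured start sequence. A minimal standalone sketch of that scan, reusing this test file's dummy THINK_START_TOKEN_ID (999); real deployments would use the model tokenizer's actual think-start token IDs:

def thinking_started(output_tokens: list[int], start_tokens: list[int]) -> bool:
    """Return True once the full think-start sequence appears in the output."""
    if not start_tokens:
        return False
    for i in range(len(output_tokens) - len(start_tokens) + 1):
        if output_tokens[i:i + len(start_tokens)] == start_tokens:
            return True
    return False

assert thinking_started([999, 12, 34], [999])       # dummy test token IDs
assert not thinking_started([12, 34], [999])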
25 changes: 25 additions & 0 deletions vllm/config/__init__.py
@@ -2422,6 +2422,29 @@ def _parse_collect_detailed_traces(self):
self.collect_detailed_traces[0].split(","))


@config
@dataclass
class ReasoningConfig:
"""Configuration for reasoning models."""

think_start_str: Optional[str] = None
"""String that indicates the start of reasoning."""
think_end_str: Optional[str] = None
"""String that indicates the end of reasoning."""
think_start_token_ids: Optional[list[int]] = None
"""Token ID that indicates the start of reasoning."""
think_end_token_ids: Optional[list[int]] = None
"""Token ID that indicates the end of reasoning."""

def is_thinking_enabled(self) -> bool:
"""Check if both start and end thinking token IDs
are set to enable thinking token budget logic."""
return (self.think_start_token_ids is not None
and self.think_end_token_ids is not None
and len(self.think_start_token_ids) > 0
and len(self.think_end_token_ids) > 0)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@@ -2473,6 +2496,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
reasoning_config: ReasoningConfig = field(default_factory=ReasoningConfig)
"""The configurations for reasoning model."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
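ReasoningConfig only enables the budget logic when both token-ID lists are present and non-empty, as is_thinking_enabled() shows. A minimal usage sketch; the token IDs below are illustrative placeholders, not values from any particular model:

from vllm.config import ReasoningConfig

cfg = ReasoningConfig(think_start_token_ids=[151667],   # placeholder IDs
                      think_end_token_ids=[151668])
assert cfg.is_thinking_enabled()

# A partially specified config leaves thinking-budget handling disabled.
assert not ReasoningConfig(think_start_token_ids=[151667]).is_thinking_enabled()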
12 changes: 9 additions & 3 deletions vllm/engine/arg_utils.py
@@ -30,9 +30,9 @@
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs)
ReasoningConfig, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs)
from vllm.config.multimodal import MMCacheType, MultiModalConfig
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.config.utils import get_field
@@ -449,6 +449,9 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
kv_events_config: Optional[KVEventsConfig] = None

reasoning_config: ReasoningConfig = get_field(VllmConfig,
"reasoning_config")

generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = \
@@ -932,6 +935,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--reasoning-config",
**vllm_kwargs["reasoning_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])

@@ -1452,6 +1457,7 @@ def create_engine_config(
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
reasoning_config=self.reasoning_config,
additional_config=self.additional_config,
)

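With the new reasoning_config field on EngineArgs, the think-token IDs can be supplied at engine construction and are forwarded to VllmConfig by create_engine_config(). A rough sketch; the model name and token IDs are placeholders:

from vllm.config import ReasoningConfig
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="org/reasoning-model",                      # placeholder model name
    reasoning_config=ReasoningConfig(
        think_start_token_ids=[151667],               # placeholder token IDs
        think_end_token_ids=[151668],
    ),
)
# create_engine_config() passes this through as vllm_config.reasoning_config.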
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -434,6 +434,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
ChatCompletionNamedToolChoiceParam,
]] = "none"
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
thinking_token_budget: Optional[int] = None
include_reasoning: bool = True

# NOTE this will be ignored by vLLM -- the model determines the behavior
Expand Down Expand Up @@ -731,6 +732,7 @@ def to_sampling_params(
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
bad_words= self.bad_words,
thinking_token_budget=self.thinking_token_budget,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
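Because thinking_token_budget is a top-level ChatCompletionRequest field, an OpenAI-compatible client can set it per request via the extra_body pass-through. A hedged example against a locally running vLLM server; the URL, model name, and 256-token cap are illustrative:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="org/reasoning-model",                      # placeholder
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    extra_body={"thinking_token_budget": 256},        # cap reasoning tokens
)
print(resp.choices[0].message.content)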
6 changes: 6 additions & 0 deletions vllm/sampling_params.py
@@ -217,6 +217,9 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: Optional[list[list[int]]] = None

thinking_token_budget: Optional[int] = None
"""Maximum number of tokens allowed for thinking operations."""

@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -232,6 +235,7 @@ def from_optional(
stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[list[str]] = None,
thinking_token_budget: Optional[int] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
Expand Down Expand Up @@ -276,6 +280,7 @@ def from_optional(
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
thinking_token_budget=thinking_token_budget,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -549,6 +554,7 @@ def __repr__(self) -> str:
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
f"thinking_token_budget={self.thinking_token_budget}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
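For offline inference the budget rides on SamplingParams directly; once the budget is exhausted, the processor forces the think-end sequence so normal generation resumes (see the test validation above). A minimal sketch; the 128-token cap is arbitrary:

from vllm import SamplingParams

params = SamplingParams(max_tokens=512, thinking_token_budget=128)
print(params)  # repr now includes thinking_token_budget=128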
10 changes: 5 additions & 5 deletions vllm/v1/sample/logits_processor/__init__.py
@@ -13,10 +13,9 @@
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
process_dict_updates)
from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor,
ThinkingTokenBudgetLogitsProcessor, process_dict_updates)
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
@@ -39,6 +38,7 @@
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
ThinkingTokenBudgetLogitsProcessor,
]


@@ -290,5 +290,5 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
"MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
"MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
"STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP",
"AdapterLogitsProcessor"
"AdapterLogitsProcessor", "ThinkingTokenBudgetLogitsProcessor"
]