Commit bf7f470

afeldman-nm, njhill, abf149, and hmellor authored
[V1] Logits processors extensibility (#19912)
Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Andrew Feldman <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
1 parent 4fc722e commit bf7f470

File tree

22 files changed, +1313 −335 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -253,6 +253,7 @@ steps:
   - pytest -v -s v1/engine
   - pytest -v -s v1/entrypoints
   - pytest -v -s v1/sample
+  - pytest -v -s v1/logits_processors
   - pytest -v -s v1/worker
   - pytest -v -s v1/structured_output
   - pytest -v -s v1/spec_decode
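
For local runs, the new suite can be invoked the same way the pipeline does (path assumed to be relative to the tests/ directory, matching the neighbouring v1 entries):

pytest -v -s v1/logits_processors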
New example file: offline inference with a custom logits processor

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates instantiating vLLM with a custom logits processor
class object.

For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.

For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)


# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        # Maps persistent-batch row index -> target token id
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Masking out tokens can change the greedy-sampling result,
        so this processor is not argmax invariant."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            if params.extra_args and (
                target_token := params.extra_args.get("target_token")
            ):
                self.req_info[index] = target_token

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        rows_list = list(self.req_info.keys())
        cols = torch.tensor(
            [self.req_info[i] for i in rows_list],
            dtype=torch.long,
            device=logits.device,
        )
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float("-inf")
        logits[rows, cols] = values_to_keep

        return logits


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[DummyLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
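
The example above masks logits, but the same interface (a constructor taking vllm_config, device, and is_pin_memory, plus is_argmax_invariant, update_state, and apply) supports other per-request behaviours. Below is an illustrative sketch only, not part of this commit: a processor that adds a bias to one token id per request. The extra_args keys `boost_token` and `boost_value`, the default bias of 5.0, and the class name are invented for this sketch.

from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor,
                                             MoveDirectionality)


class TokenBoostLogitsProcessor(LogitsProcessor):
    """Adds a fixed bias to one token id per request (illustrative only)."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        # persistent-batch row index -> (token id, bias)
        self.req_info: dict[int, tuple[int, float]] = {}

    def is_argmax_invariant(self) -> bool:
        # Biasing logits can change the greedy-sampling result.
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return
        for index, params, _, _ in batch_update.added:
            args = (params.extra_args or {}) if params else {}
            if "boost_token" in args:
                self.req_info[index] = (args["boost_token"],
                                        args.get("boost_value", 5.0))
        for index in batch_update.removed:
            self.req_info.pop(index, None)
        for adx, bdx, direct in batch_update.moved:
            a_val = self.req_info.pop(adx, None)
            b_val = self.req_info.pop(bdx, None)
            if a_val is not None:
                self.req_info[bdx] = a_val
            if direct == MoveDirectionality.SWAP and b_val is not None:
                self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Add the per-request bias in place; rows without an entry are untouched.
        for row, (token_id, bias) in self.req_info.items():
            logits[row, token_id] += bias
        return logits

Such a class would be registered the same way as the dummy processor above, e.g. LLM(model=..., logits_processors=[TokenBoostLogitsProcessor]).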

tests/utils.py

Lines changed: 66 additions & 13 deletions
@@ -13,6 +13,7 @@
 import time
 import warnings
 from contextlib import contextmanager, suppress
+from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union

@@ -76,6 +77,23 @@ def _nvml():
 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

+    def _start_server(self, model: str, vllm_serve_args: list[str],
+                      env_dict: Optional[dict[str, str]]) -> None:
+        """Subclasses override this method to customize server process launch
+        """
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        if env_dict is not None:
+            env.update(env_dict)
+        self.proc: subprocess.Popen = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+
     def __init__(self,
                  model: str,
                  vllm_serve_args: list[str],

@@ -128,18 +146,7 @@ def __init__(self,
         model_loader = get_model_loader(load_config)
         model_loader.download_model(model_config)

-        env = os.environ.copy()
-        # the current process might initialize cuda,
-        # to be safe, we should use spawn method
-        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
-        if env_dict is not None:
-            env.update(env_dict)
-        self.proc = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
+        self._start_server(model, vllm_serve_args, env_dict)
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)

@@ -155,6 +162,10 @@ def __exit__(self, exc_type, exc_value, traceback):
             # force kill if needed
             self.proc.kill()

+    def _poll(self) -> Optional[int]:
+        """Subclasses override this method to customize process polling"""
+        return self.proc.poll()
+
     def _wait_for_server(self, *, url: str, timeout: float):
         # run health check
         start = time.time()

@@ -169,7 +180,7 @@ def _wait_for_server(self, *, url: str, timeout: float):
                 # which means the server is not ready yet.
                 # the stack trace is not useful, so we suppress it
                 # by using `raise from None`.
-                result = self.proc.poll()
+                result = self._poll()
                 if result is not None and result != 0:
                     raise RuntimeError("Server exited unexpectedly.") from None

@@ -205,6 +216,48 @@ def get_async_client(self, **kwargs):
                                  **kwargs)


+class RemoteOpenAIServerCustom(RemoteOpenAIServer):
+    """Launch test server with custom child process"""
+
+    def _start_server(self, model: str, vllm_serve_args: list[str],
+                      env_dict: Optional[dict[str, str]]) -> None:
+        self.proc: Process = Process(
+            target=self.child_process_fxn,
+            args=(env_dict, model,
+                  vllm_serve_args))  # type: ignore[assignment]
+        self.proc.start()
+
+    def __init__(self,
+                 model: str,
+                 vllm_serve_args: list[str],
+                 child_process_fxn: Callable[
+                     [Optional[dict[str, str]], str, list[str]], None],
+                 *,
+                 env_dict: Optional[dict[str, str]] = None,
+                 seed: Optional[int] = 0,
+                 auto_port: bool = True,
+                 max_wait_seconds: Optional[float] = None) -> None:
+        """Store custom child process function then invoke superclass
+        constructor which will indirectly launch it."""
+        self.child_process_fxn = child_process_fxn
+        super().__init__(model=model,
+                         vllm_serve_args=vllm_serve_args,
+                         env_dict=env_dict,
+                         seed=seed,
+                         auto_port=auto_port,
+                         max_wait_seconds=max_wait_seconds)
+
+    def _poll(self) -> Optional[int]:
+        return self.proc.exitcode
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+        self.proc.join(8)
+        if self.proc.is_alive():
+            # force kill if needed
+            self.proc.kill()
+
+
 def _test_completion(
     client: openai.OpenAI,
     model: str,
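
For orientation, here is a hedged sketch of how a test might use the new RemoteOpenAIServerCustom hook. Only the class and its constructor signature come from the diff above; the child-process function `run_server_in_child`, the model choice, and the `--max-model-len` argument are illustrative, and exec-ing `vllm serve` in the child is just one way to satisfy the callback contract.

import os
from typing import Optional

import requests

from tests.utils import RemoteOpenAIServerCustom


def run_server_in_child(env_dict: Optional[dict[str, str]], model: str,
                        vllm_serve_args: list[str]) -> None:
    """Child-process entry point (hypothetical).

    Applies the extra environment and then replaces the child process with
    `vllm serve`, so that terminate() and exitcode act on the real server.
    """
    if env_dict:
        os.environ.update(env_dict)
    os.execvp("vllm", ["vllm", "serve", model, *vllm_serve_args])


def test_server_with_custom_child_process():
    with RemoteOpenAIServerCustom("facebook/opt-125m",
                                  ["--max-model-len", "1024"],
                                  run_server_in_child) as server:
        # Once the context manager returns, the base-class helpers
        # (e.g. url_for) work as usual.
        assert requests.get(server.url_for("health")).status_code == 200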

tests/v1/logits_processors/__init__.py

Whitespace-only changes.

tests/v1/sample/test_logits_processors.py renamed to tests/v1/logits_processors/test_correctness.py

Lines changed: 16 additions & 8 deletions
@@ -9,11 +9,13 @@
 import pytest
 import torch

+from tests.utils import create_new_process_for_each_test
 from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
                                    create_penalty_tensor,
                                    create_prompt_tokens_tensor,
                                    fake_apply_logitsprocs,
                                    fake_update_logitsprocs_state)
+from vllm.config import VllmConfig
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available

@@ -24,7 +26,7 @@
                                              MinPLogitsProcessor,
                                              MinTokensLogitsProcessor,
                                              MoveDirectionality,
-                                             init_builtin_logitsprocs)
+                                             build_logitsprocs)
 # yapf: enable
 from vllm.v1.sample.metadata import SamplingMetadata

@@ -53,6 +55,7 @@ class LogitsProcsRequestParams:
     workload_index: int
     logitproc_type: LogitprocType  # Logitproc enabled, specified by str id
     out_tokens: list[int]  # Output tokens required for min tokens test
+    prompt_tokens: list[int]  # Dummy prompt tokens placeholder
     params: SamplingParams  # Settings customized for logitproc

     def __init__(self, workload_index: int, logitproc_type: LogitprocType):

@@ -63,6 +66,7 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType):
         # don't matter *for these tests* so use 0 as a dummy value
         self.out_tokens = ([0] *
                            (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
+        self.prompt_tokens = []
         self.params = _sampling_params_from_logitproc(logitproc_type)

     def __str__(self):

@@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata(
             vocab_size,
             size=np.random.randint(
                 1, MAX_NUM_PROMPT_TOKENS)).tolist())
-    logitsprocs = init_builtin_logitsprocs(
-        pin_memory_available=PIN_MEMORY_AVAILABLE,
-        max_num_reqs=MAX_NUM_REQS + 1,
-        device=device)
-
+    logitsprocs = build_logitsprocs(
+        vllm_config=VllmConfig(),
+        device=device,
+        is_pin_memory=PIN_MEMORY_AVAILABLE,
+        is_pooling_model=False,
+    )
     fake_sampling_metadata = SamplingMetadata(
         temperature=torch.full((batch_size, ), 0.0),
         all_greedy=True,

@@ -462,15 +467,17 @@ def _generate_fake_step_update(
         # Replace as many removed requests as possible with added requests
         add_remove_idx = batch_update_builder.pop_removed()
         batch_update_builder.added.append(
-            (add_remove_idx, add_req_params.params, add_req_params.out_tokens))
+            (add_remove_idx, add_req_params.params,
+             add_req_params.prompt_tokens, add_req_params.out_tokens))
         persistent_batch[add_remove_idx] = add_req_params

     # Append remaining added requests to end of batch
     add_reqs_append = workload_params[(wdx +
                                        num_step_add_replace):(wdx +
                                                               num_step_add)]
     batch_update_builder.added.extend([
-        (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
+        (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
+         add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
     ])
     persistent_batch.extend(add_reqs_append)

@@ -561,6 +568,7 @@ def _assert_valid(
                          step_idx=step_idx)


+@create_new_process_for_each_test()
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
 @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
