Commit c0efdd6

Authored by aarnphm, rshaw@neuralmagic.com, russellb, and robertgshaw2-redhat

[Fix][Structured Output] using vocab_size to construct matcher (#14868)

Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: Robert Shaw <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent aaaec52 commit c0efdd6

File tree (7 files changed, +70 -85 lines changed):

  .buildkite/test-pipeline.yaml
  requirements/common.txt
  tests/model_executor/test_guided_processors.py
  tests/v1/entrypoints/llm/test_struct_output_generate.py
  vllm/model_executor/guided_decoding/__init__.py
  vllm/model_executor/guided_decoding/xgrammar_decoding.py
  vllm/v1/structured_output/__init__.py

.buildkite/test-pipeline.yaml (+1)

@@ -200,6 +200,7 @@ steps:
   - pytest -v -s v1/core
   - pytest -v -s v1/entrypoints
   - pytest -v -s v1/engine
+  - pytest -v -s v1/entrypoints
   - pytest -v -s v1/sample
   - pytest -v -s v1/worker
   - pytest -v -s v1/structured_output

requirements/common.txt (+1 -1)

@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
 outlines == 0.1.11
 lark == 1.2.2
-xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
+xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
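Note on the pin: the bump from 0.1.15 to 0.1.16 picks up xgrammar's fix for deriving special token ids when the model's vocab_size exceeds the tokenizer vocabulary (see the NOTE and linked xgrammar commit in xgrammar_decoding.py below). The environment marker means xgrammar is simply absent on other architectures, so callers detect it by import; a minimal sketch of that guard, mirroring the try/except in xgrammar_decoding.py (the print line is purely illustrative):

import platform

# Where the requirements marker excludes the platform, the package is not
# installed at all, so feature detection happens via import, not via an
# architecture check.
try:
    import xgrammar as xgr  # noqa: F401
    xgr_installed = True
except ImportError:
    xgr_installed = False

print(f"machine={platform.machine()}, xgrammar installed={xgr_installed}")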

tests/model_executor/test_guided_processors.py (+8 -4)

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import pickle
 
 import pytest
@@ -208,21 +209,24 @@ def test_guided_decoding_backend_options():
 
 
 def test_pickle_xgrammar_tokenizer_data():
-
-    # TODO: move to another test file for xgrammar
     try:
         import xgrammar as xgr
     except ImportError:
         pytest.skip("Could not import xgrammar to run test")
 
     from vllm.model_executor.guided_decoding.xgrammar_decoding import (
         TokenizerData)
-    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    tokenizer_data = TokenizerData(
+        metadata=
+        '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
+        encoded_vocab=['!', '"', '#', '$', '%'],
+    )
     pickled = pickle.dumps(tokenizer_data)
 
     assert pickled is not None
 
     depickled: TokenizerData = pickle.loads(pickled)
 
     assert depickled is not None
-    assert depickled.vocab_type == xgr.VocabType.RAW
+    assert json.loads(
+        depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
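The rewritten test pins down what now has to survive pickling: two plain strings. A minimal stand-in sketch of the same round-trip that runs without xgrammar installed (FakeTokenizerData is a hypothetical mirror of TokenizerData; the constant 2 matches the test's assertion that vocab_type 2 is xgr.VocabType.BYTE_LEVEL):

import json
import pickle
from dataclasses import dataclass, field


@dataclass(frozen=True)
class FakeTokenizerData:
    # Mirrors TokenizerData: just picklable strings, no tokenizer objects.
    metadata: str
    encoded_vocab: list[str] = field(default_factory=list)


data = FakeTokenizerData(
    metadata='{"vocab_type":2,"vocab_size":151665,'
    '"add_prefix_space":false,"stop_token_ids":[151645]}',
    encoded_vocab=['!', '"', '#', '$', '%'],
)
roundtripped = pickle.loads(pickle.dumps(data))
assert json.loads(roundtripped.metadata)["vocab_type"] == 2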

tests/v1/entrypoints/llm/test_struct_output_generate.py (-3)

@@ -18,9 +18,6 @@
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
 
-# Undo after https://github.com/vllm-project/vllm/pull/14868
-pytest.skip(allow_module_level=True)
-
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
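For context, the deleted guard is pytest's module-level skip: executed at import time, it removes every test in the file from collection, which is how this suite was parked until #14868 landed. A minimal sketch of the pattern, with FIX_LANDED as a hypothetical flag:

import pytest

FIX_LANDED = True  # hypothetical stand-in for "has the upstream fix merged?"

if not FIX_LANDED:
    # Without allow_module_level=True, pytest.skip() may only be called
    # from inside a test; with it, the whole module is skipped at collection.
    pytest.skip("waiting on upstream fix", allow_module_level=True)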

vllm/model_executor/guided_decoding/__init__.py (+1 -9)

@@ -9,7 +9,6 @@
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -26,7 +25,7 @@ def maybe_backend_fallback(
 
 def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                       fallback: str) -> None:
-    """Change the backend to the specified fallback with a warning log,
+    """Change the backend to the specified fallback with a warning log,
     or raise a ValueError if the `no-fallback` option is specified."""
     if guided_params.no_fallback():
         raise ValueError(message)
@@ -53,19 +52,12 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
     if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
-        # xgrammar only has x86 wheels for linux, fallback to outlines
-        from vllm.platforms import current_platform
-        if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            fallback_or_error(guided_params,
-                              "xgrammar is only supported on x86 CPUs.",
-                              "outlines")
 
         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
             fallback_or_error(
                 guided_params,
                 "xgrammar does not support regex guided decoding.", "outlines")
-
         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
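After this diff, backend fallback for xgrammar is purely feature-driven; the CPU-architecture check is gone along with the CpuArchEnum import. A self-contained sketch of the remaining decision flow (Params and has_unsupported_json are simplified stand-ins for GuidedDecodingParams and has_xgrammar_unsupported_json_features; the "pattern" check is a crude illustration, not the real feature list):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Params:
    regex: str | None = None
    json: str | None = None


def has_unsupported_json(schema: str) -> bool:
    # Crude stand-in: the real helper inspects the parsed schema.
    return "pattern" in schema


def choose_backend(params: Params) -> str:
    if params.regex is not None:
        return "outlines"  # xgrammar does not support regex guided decoding
    if params.json is not None and has_unsupported_json(params.json):
        return "outlines"  # xgrammar lacks some JSON schema features
    return "xgrammar"


assert choose_backend(Params(regex=r"\d+")) == "outlines"
assert choose_backend(Params(json='{"type": "object"}')) == "xgrammar"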

vllm/model_executor/guided_decoding/xgrammar_decoding.py (+58 -67)

@@ -9,13 +9,11 @@
 from typing import TYPE_CHECKING, Any, List
 
 import torch
-from transformers import PreTrainedTokenizerFast
 
 from vllm.logger import init_logger
 
 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@
 logger = init_logger(__name__)
 
 
-# TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,65 +49,61 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"
 
 
 class TokenizerDataCache:
     """Cache manager for tokenizer data to avoid repeated processing."""
     _cache: dict[int, TokenizerData] = {}
 
     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:
 
         if tokenizer_hash not in cls._cache:
-            # Vendored from xgrammar logic since we cannot pickle the tokenizer
-            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We will need to use lm_head's vocab_size
+                # to determine correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())
+
+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
 
-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
-
-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
+            # maintain tokenizer's indexing
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token
 
-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
+                    "add_prefix_space": True
+                })
 
             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
            )
 
         return cls._cache[tokenizer_hash]
@@ -129,30 +122,15 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
         cache_key = str(config.tokenizer_hash)
 
         if cache_key not in cls._cache:
-            assert config.tokenizer_data is not None
-            assert config.tokenizer_data.encoded_vocab is not None
-
             config_data = config.tokenizer_data
 
             # In TokenizerDataCache.get_tokenizer_data, a serializable
             # tokenizer_data is created and cached. This data is used to build
             # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            #   xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            #   xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
 
@@ -163,13 +141,12 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None
 
     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@ def from_guided_params(cls,
                            max_threads: int = 8) -> GrammarConfig:
 
         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err
 
             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err
 
             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@ def from_guided_params(cls,
 
             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@ def choice_as_grammar(choice: List[str] | None) -> str:
         grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar
 
+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
+
 
 @dataclass
 class XGrammarLogitsProcessor:
@@ -299,18 +283,25 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None
 
     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)
 
+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}
 
     def __setstate__(self, state: dict[str, Any]):
         self.config = state['config']
         self.reasoner = state['reasoner']
 
+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
@@ -352,7 +343,7 @@ def __call__(self, input_ids: list[int],
             xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
         ]
         self.token_bitmask = xgr.allocate_token_bitmask(
-            self.batch_size, self.config.vocab_size)
+            self.batch_size, self.tokenizer_info.vocab_size)
 
         if not self.prefilled:
             # Have not sampled a token yet
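The thrust of this file's change: build xgr.TokenizerInfo once with the model's lm_head vocab_size, serialize it to two picklable strings, and rebuild it anywhere with from_vocab_and_metadata, so no consumer needs the (unpicklable) tokenizer or a separately tracked vocab_size. A minimal round-trip sketch, assuming xgrammar == 0.1.16 and transformers are installed; the model name and lm_head size are illustrative stand-ins:

import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
lm_head_size = 151936  # hypothetical: model_config.hf_text_config.vocab_size

# Driver side: vocab_size comes from the model, not the tokenizer, so
# special/stop token ids are derived against the padded lm_head dimension.
info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=lm_head_size)
metadata = info.dump_metadata()  # plain JSON string, safe to pickle

encoded_vocab = [""] * info.vocab_size  # maintain tokenizer's indexing
for token, idx in tokenizer.get_vocab().items():
    if idx < info.vocab_size:
        encoded_vocab[idx] = token

# Consumer side (e.g. after crossing a process boundary): rebuild without
# the tokenizer, exactly as get_compiler does above.
rebuilt = xgr.TokenizerInfo.from_vocab_and_metadata(
    encoded_vocab=encoded_vocab, metadata=metadata)
assert rebuilt.vocab_size == info.vocab_size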

vllm/v1/structured_output/__init__.py (+1 -1)

@@ -40,7 +40,7 @@ def _delayed_init(self):
         tokenizer_group.ping()
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = len(tokenizer.get_vocab())
+        self.vocab_size = self.vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
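This one-liner is the V1 engine's half of the same fix: the structured-output manager must size its state by the model's vocabulary (get_vocab_size(), i.e. the lm_head dimension), not by the tokenizer's, because the two can differ when the embedding table is padded. A toy illustration with hypothetical numbers:

# Hypothetical sizes for a padded-vocab model; illustrative only.
tokenizer_vocab_len = 151665  # len(tokenizer.get_vocab())
model_vocab_size = 151936     # model_config.get_vocab_size(), lm_head dim

# Logits have model_vocab_size columns, so token ids in the padded tail
# must still be maskable; sizing from the tokenizer under-allocates.
assert model_vocab_size >= tokenizer_vocab_len
print(f"{model_vocab_size - tokenizer_vocab_len} padded ids must be maskable")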
