@@ -9,13 +9,11 @@
 from typing import TYPE_CHECKING, Any, List
 
 import torch
-from transformers import PreTrainedTokenizerFast
 
 from vllm.logger import init_logger
 
 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@
 logger = init_logger(__name__)
 
 
-# TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,65 +49,61 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"
 
 
 class TokenizerDataCache:
     """Cache manager for tokenizer data to avoid repeated processing."""
     _cache: dict[int, TokenizerData] = {}
 
     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:
 
         if tokenizer_hash not in cls._cache:
-            # Vendored from xgrammar logic since we cannot pickle the tokenizer
-            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We need to use lm_head's vocab_size to determine
+                # the correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())
+
+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
 
-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
-
-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
+            # Maintain the tokenizer's own indexing.
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token
 
-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
+                    "add_prefix_space": True
+                })
 
             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
+            )
 
         return cls._cache[tokenizer_hash]
 
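The point of the new `TokenizerData` is that it holds only plain, picklable data (`encoded_vocab` plus a JSON `metadata` string) from which a full `xgr.TokenizerInfo` can be rebuilt without the tokenizer object. Below is a minimal sketch of that round-trip; the xgrammar calls are the same ones this diff uses, while the `gpt2` checkpoint and the use of `len(tokenizer)` as the vocab size are illustrative choices only (vLLM takes the size from the model config's lm_head instead):

```python
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model

# Build TokenizerInfo once from the live tokenizer object...
info = xgr.TokenizerInfo.from_huggingface(tokenizer,
                                          vocab_size=len(tokenizer))
metadata = info.dump_metadata()  # JSON string: picklable, cacheable

# ...and capture the vocab in token-id order, as the cache above does.
encoded_vocab = [""] * info.vocab_size
for token, idx in tokenizer.get_vocab().items():
    if idx < info.vocab_size:
        encoded_vocab[idx] = token

# Later (e.g. in another process), no tokenizer object is needed.
rebuilt = xgr.TokenizerInfo.from_vocab_and_metadata(
    encoded_vocab=encoded_vocab, metadata=metadata)
assert rebuilt.vocab_size == info.vocab_size
```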
@@ -129,30 +122,15 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
         cache_key = str(config.tokenizer_hash)
 
         if cache_key not in cls._cache:
-            assert config.tokenizer_data is not None
-            assert config.tokenizer_data.encoded_vocab is not None
-
             config_data = config.tokenizer_data
 
             # In TokenizerDataCache.get_tokenizer_data, a serializable
             # tokenizer_data is created and cached. This data is used to build
             # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            #   xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            #   xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
 
@@ -163,13 +141,12 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None
 
     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@ def from_guided_params(cls,
                            max_threads: int = 8) -> GrammarConfig:
 
         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err
 
             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err
 
             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@ def from_guided_params(cls,
 
             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@ def choice_as_grammar(choice: List[str] | None) -> str:
         grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar
 
+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
+
 
 @dataclass
 class XGrammarLogitsProcessor:
@@ -299,18 +283,25 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None
 
     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)
 
+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}
 
     def __setstate__(self, state: dict[str, Any]):
         self.config = state['config']
         self.reasoner = state['reasoner']
 
+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
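Since `__getstate__` keeps only `config` and `reasoner`, everything else, including the new `tokenizer_info`, is rebuilt on unpickle from the serializable `TokenizerData`. A quick sketch of the invariant this preserves, assuming `proc` is an already-constructed `XGrammarLogitsProcessor` (not built here):

```python
import pickle

# Round-trip the processor; only config/reasoner cross the pickle
# boundary, and __setstate__ restores tokenizer_info from the config.
clone = pickle.loads(pickle.dumps(proc))

assert clone.tokenizer_info.vocab_size == proc.tokenizer_info.vocab_size
assert clone.ctx is None and clone.matchers == []  # rebuilt lazily
```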
@@ -352,7 +343,7 @@ def __call__(self, input_ids: list[int],
             xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
         ]
         self.token_bitmask = xgr.allocate_token_bitmask(
-            self.batch_size, self.config.vocab_size)
+            self.batch_size, self.tokenizer_info.vocab_size)
 
         if not self.prefilled:
             # Have not sampled a token yet
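The bitmask is now sized from `tokenizer_info.vocab_size` (which reflects the lm_head size passed to `from_huggingface`) instead of a separately tracked `vocab_size` field, so the two can no longer drift apart. For context, a minimal sketch of the standard xgrammar mask-and-apply loop this bitmask feeds; `compiled` (an `xgr.CompiledGrammar`), `tokenizer_info`, and `logits` (batch x vocab tensor) are assumed to exist and the names are illustrative:

```python
import xgrammar as xgr

batch_size = logits.shape[0]
matchers = [xgr.GrammarMatcher(compiled) for _ in range(batch_size)]
bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)

# Each matcher marks which tokens its grammar allows next.
for i, matcher in enumerate(matchers):
    matcher.fill_next_token_bitmask(bitmask, i)

# Set forbidden tokens' logits to -inf before sampling.
xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
```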