
Commit 0455c46

[Core] Factor out common code in SequenceData and Sequence (vllm-project#8675)
1 parent d4bf085

File tree

8 files changed: +64 -97 lines changed
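In short, the commit replaces every direct `SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, ...))` construction with factory methods on `SequenceData` itself, so call sites no longer depend on the token-array typecode. A minimal before/after sketch (the token IDs are arbitrary test values):

    # Before: every caller built the typed array by hand.
    from array import array
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

    # After: the factory method hides the array representation.
    from vllm.sequence import SequenceData

    seq_data = SequenceData.from_seqs([1, 2, 3])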

tests/samplers/test_sampler.py

Lines changed: 7 additions & 20 deletions

@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch

@@ -12,8 +11,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available


@@ -59,9 +57,7 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -205,9 +201,8 @@ def create_sampling_params(min_tokens,
         return sampling_params

     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0:
-                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                       [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))

tests/spec_decode/utils.py

Lines changed: 3 additions & 9 deletions

@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},
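As the hunk above shows, `from_seqs` also takes pre-generated output tokens as an optional second positional argument, replacing the private `_output_token_ids` keyword. A short sketch of the new call shape (the token values are illustrative):

    from vllm.sequence import SequenceData

    prompt_token_ids = [10, 11, 12]  # illustrative prompt tokens
    cont_token_ids = [13, 14]        # continuation tokens the test pre-seeds

    # Prompt and output tokens in one call; the factory converts both
    # into the internal typed arrays.
    seq_data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)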

tests/test_logits_processor.py

Lines changed: 2 additions & 6 deletions

@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch

@@ -9,8 +8,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available


@@ -71,9 +69,7 @@ def pick_ith(token_ids, logits):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},

tests/test_sequence.py

Lines changed: 2 additions & 5 deletions

@@ -1,10 +1,7 @@
-from array import array
-
 import pytest

 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)

 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):


 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2
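The prefill test keeps its assertions; only the construction moves to `from_seqs`. For context, a sketch of the bookkeeping it verifies, continuing past the `# advance by 2` comment under the assumption that the rest of the test body matches the original file:

    from vllm.sequence import SequenceData

    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
    assert seq_data.get_num_uncomputed_tokens() == 4
    assert seq_data.get_num_computed_tokens() == 0

    # advance by 2: two of the four prompt tokens are now computed
    seq_data.update_num_computed_tokens(2)
    assert seq_data.get_num_uncomputed_tokens() == 2
    assert seq_data.get_num_computed_tokens() == 2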

tests/worker/test_encoder_decoder_model_runner.py

Lines changed: 7 additions & 15 deletions

@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List

 import pytest
 import torch

 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))

         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,

tests/worker/test_model_runner.py

Lines changed: 5 additions & 11 deletions

@@ -1,4 +1,3 @@
-from array import array
 from typing import List

 import pytest
@@ -8,8 +7,7 @@
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(

vllm/inputs/registry.py

Lines changed: 1 addition & 7 deletions

@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@

 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)

-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
-

 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ def _default_dummy_data_factory(
         # Avoid circular import
         from vllm.sequence import SequenceData

-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
        dummy_multi_modal_data = None

         return dummy_seq_data, dummy_multi_modal_data
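`from_counts` covers the dummy-data case where a prompt is just one token repeated N times; the old code expressed this by multiplying a one-element array. This change also lets the duplicated `VLLM_TOKEN_ID_ARRAY_TYPE` constant (kept here only to dodge a circular import) be deleted. A sketch of the equivalence, with `seq_len = 8` as an arbitrary example value:

    from array import array
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    seq_len = 8  # illustrative

    old = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
    new = SequenceData.from_counts({0: seq_len})

    # Both should describe a prompt of eight 0-tokens; comparing the
    # prompt token IDs (accessor assumed) shows they match.
    assert list(old.prompt_token_ids) == list(new.prompt_token_ids)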

vllm/sequence.py

Lines changed: 37 additions & 24 deletions

@@ -5,6 +5,7 @@
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None

+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
@@ -370,8 +400,6 @@ def __init__(
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None

         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ def __init__(
                 f"invalid input {inputs}; did you forget the "
                 "encoder input prompt fields?")

-        self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
+        self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

@@ -422,37 +449,23 @@ def __init__(
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size

-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")

-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))

-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")

-        # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
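The last hunks swap the hand-rolled `_prompt`/`_prompt_token_ids` caches for `functools.cached_property`, which likewise computes the value once per instance and then serves it from the instance `__dict__`. A generic illustration of that behavior (the `Example` class below is hypothetical, not vLLM code):

    from functools import cached_property

    class Example:
        def __init__(self, inputs):
            self.inputs = inputs

        @cached_property
        def prompt(self):
            # Runs once per instance; the result is stored in the
            # instance's __dict__ and reused on later accesses.
            print("computing")
            return self.inputs.get("prompt")

    e = Example({"prompt": "hello"})
    e.prompt  # prints "computing"
    e.prompt  # served from the cache; no print

One consequence worth noting: like the manual caches it replaces, `cached_property` assumes `inputs` does not change after the first access.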
