[MISC] add arg pad_for_invariant_seq_len
Signed-off-by: Mengqing Cao <[email protected]>
MengqingCao committed Feb 5, 2025
1 parent 233df6f commit 9092acd
Showing 1 changed file with 12 additions and 1 deletion.
vllm/model_executor/sampling_metadata.py: 12 additions, 1 deletion
@@ -153,14 +153,16 @@ def prepare(
         pin_memory: bool,
         generators: Optional[Dict[str, torch.Generator]] = None,
         cache: Optional[SamplingMetadataCache] = None,
+        pad_for_invariant_seq_len: Optional[bool] = False,
     ) -> "SamplingMetadata":
         (
             seq_groups,
             selected_token_indices,
             categorized_sample_indices,
             num_prompts,
         ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
-                                device, generators, cache)
+                                device, generators, cache,
+                                pad_for_invariant_seq_len)
         selected_token_indices = async_tensor_h2d(
             selected_token_indices,
             dtype=torch.long,
@@ -201,6 +203,7 @@ def _prepare_seq_groups(
     device: str,
     generators: Optional[Dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
+    pad_for_invariant_seq_len: Optional[bool] = False,
 ) -> Tuple[
     List[SequenceGroupToSample],
     List[int],
@@ -219,6 +222,9 @@ def _prepare_seq_groups(
             `SequenceGroupToSample.generator`.
         generators: A store of per-request random number generators used
             for seeded requests.
+        pad_for_invariant_seq_len: A flag indicating whether padding is required.
+            Padding is required when the input tokens/positions of different
+            batches need to be aligned to the same length `max_seq_len`.
 
     Returns:
         seq_groups: A list of sequence group to sample.
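Editor's note, to make the new docstring concrete: "aligned to the same length" means every group's input rows are stretched to the longest query in the batch. The snippet below is a hypothetical illustration only; the token values and pad id are made up, and left-padding is an assumption inferred from where the patch skips padding rows (it is not code from this file).

# Hypothetical illustration of the invariant-length layout the flag assumes.
query_lens = [1, 1, 3]           # two decode groups, one prefill group
max_seq_len = max(query_lens)    # every group is stretched to 3 rows
PAD = 0                          # made-up pad token id
rows = [[101], [102], [7, 8, 9]]
padded = [[PAD] * (max_seq_len - len(r)) + r for r in rows]
print(padded)  # [[0, 0, 101], [0, 0, 102], [7, 8, 9]]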
@@ -265,6 +271,7 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
+        padding_len: int = 0
         prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
         sample_indices: List[int] = (sample_obj.sample_indices
@@ -294,6 +301,8 @@ def _prepare_seq_groups(
             query_len = query_lens[i] if query_lens is not None and len(
                 query_lens) > 0 else 1
             sample_len = len(seq_ids) * query_len if do_sample else 0
+            if pad_for_invariant_seq_len:
+                padding_len = max(query_lens) - sample_len - prompt_logprob_len
 
         if sampling_params.seed is not None and generators is not None:
             generator = generators.get(seq_group_metadata.request_id)
@@ -311,6 +320,8 @@ def _prepare_seq_groups(
         selected_token_indices.extend(
             range(model_output_idx, model_output_idx + prompt_logprob_len))
         model_output_idx += prompt_logprob_len
+        if pad_for_invariant_seq_len:
+            model_output_idx += padding_len
         if do_sample:
             selected_token_indices.extend(
                 range(model_output_idx, model_output_idx + sample_len))
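Editor's note: the index arithmetic in the last two hunks is easiest to see with concrete numbers. The sketch below is a minimal, self-contained reconstruction of the patched walk over sequence groups, not the real vLLM code path; the function name build_selected_indices, the one-sequence-per-group simplification, and the example inputs are all made up for illustration.

from typing import List


def build_selected_indices(query_lens: List[int],
                           do_sample: List[bool],
                           prompt_logprob_lens: List[int],
                           pad_for_invariant_seq_len: bool = False) -> List[int]:
    """Toy reconstruction of the index walk patched above.

    Each seq group contributes prompt_logprob_len + sample_len rows of
    model output; with pad_for_invariant_seq_len=True every group is
    assumed to span max(query_lens) rows, so the walk skips padding_len
    rows before reaching the group's sampled tokens.
    """
    selected_token_indices: List[int] = []
    model_output_idx = 0
    for query_len, sample, prompt_logprob_len in zip(
            query_lens, do_sample, prompt_logprob_lens):
        # One sequence per group here, so sample_len == query_len when sampling.
        sample_len = query_len if sample else 0
        padding_len = 0
        if pad_for_invariant_seq_len:
            # Mirrors the patch: rows needed to stretch this group to the
            # batch-wide invariant length max(query_lens).
            padding_len = max(query_lens) - sample_len - prompt_logprob_len
        # Rows holding prompt logprobs come first.
        selected_token_indices.extend(
            range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if pad_for_invariant_seq_len:
            model_output_idx += padding_len  # skip the pad rows
        if sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
            model_output_idx += sample_len
    return selected_token_indices


# Two decode groups (query_len=1) batched with one prefill group (query_len=3).
print(build_selected_indices([1, 1, 3], [True, True, True], [0, 0, 0]))
print(build_selected_indices([1, 1, 3], [True, True, True], [0, 0, 0],
                             pad_for_invariant_seq_len=True))

Without the flag the rows are packed back to back and the call prints [0, 1, 2, 3, 4]; with it each group occupies max(query_lens) = 3 rows and the sampled rows shift to [2, 5, 6, 7, 8].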
