Skip to content

Commit db35186

Browse files
authored
[Core] Comment out unused code in sampler (vllm-project#7023)
1 parent 660dea1 commit db35186

File tree

1 file changed: 31 additions and 27 deletions

vllm/model_executor/sampling_metadata.py

Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313

1414
_SAMPLING_EPS = 1e-5
1515
_SEED_0_REPLACEMENT = 3403598558
16+
# Triton sampler code paths are gated behind this flag until the sampler is ready.
17+
_USE_TRITON_SAMPLER = False
1618

1719

1820
@dataclass
@@ -347,14 +349,16 @@ def from_sampling_metadata(
347349
repetition_penalties: List[float] = []
348350
sampling_seeds: List[int] = []
349351
sample_indices: List[int] = []
350-
prompt_best_of: List[int] = []
351352
do_penalties = False
352353
do_top_p_top_k = False
353354
do_min_p = False
354355

355-
# We need one base seed per Triton slice.
356-
seeds_to_generate = (extra_seeds_to_generate +
357-
get_num_triton_sampler_splits(vocab_size))
356+
if _USE_TRITON_SAMPLER:
357+
prompt_best_of: List[int] = []
358+
359+
# We need one base seed per Triton slice.
360+
seeds_to_generate = (extra_seeds_to_generate +
361+
get_num_triton_sampler_splits(vocab_size))
358362

359363
assert sampling_metadata.seq_groups is not None
360364
for seq_group in sampling_metadata.seq_groups:
@@ -366,9 +370,6 @@ def from_sampling_metadata(
366370
r = sampling_params.repetition_penalty
367371
top_p = sampling_params.top_p
368372
min_p = sampling_params.min_p
369-
seed = sampling_params.seed
370-
371-
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
372373

373374
# k should not be greater than the vocab size.
374375
top_k = min(sampling_params.top_k, vocab_size)
@@ -389,8 +390,7 @@ def from_sampling_metadata(
389390
do_penalties = True
390391

391392
is_prompt = seq_group.is_prompt
392-
if (seq_group.is_prompt
393-
and sampling_params.prompt_logprobs is not None):
393+
if (is_prompt and sampling_params.prompt_logprobs is not None):
394394
# For tokens in the prompt that we only need to get
395395
# their logprobs
396396
query_len = seq_group.query_len
@@ -415,23 +415,27 @@ def from_sampling_metadata(
415415
frequency_penalties += [f] * len(seq_ids)
416416
repetition_penalties += [r] * len(seq_ids)
417417

418-
if is_prompt:
419-
prompt_best_of.append(sampling_params.best_of)
420-
query_len = seq_group.query_len
421-
assert query_len is not None
422-
423-
for seq_id in seq_ids:
424-
seq_data = seq_group.seq_data[seq_id]
425-
extra_entropy = extra_entropy or ()
426-
seq_seeds = cls._get_sequence_seeds(
427-
seed,
428-
seq_data.get_len(),
429-
*extra_entropy,
430-
seq_id,
431-
seeds_to_generate=seeds_to_generate,
432-
is_greedy=is_greedy)
433-
sampling_seeds.append(seq_seeds)
434-
sample_indices.extend(seq_group.sample_indices)
418+
if _USE_TRITON_SAMPLER:
419+
if is_prompt:
420+
prompt_best_of.append(sampling_params.best_of)
421+
query_len = seq_group.query_len
422+
assert query_len is not None
423+
424+
seed = sampling_params.seed
425+
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
426+
427+
for seq_id in seq_ids:
428+
seq_data = seq_group.seq_data[seq_id]
429+
extra_entropy = extra_entropy or ()
430+
seq_seeds = cls._get_sequence_seeds(
431+
seed,
432+
seq_data.get_len(),
433+
*extra_entropy,
434+
seq_id,
435+
seeds_to_generate=seeds_to_generate,
436+
is_greedy=is_greedy)
437+
sampling_seeds.append(seq_seeds)
438+
sample_indices.extend(seq_group.sample_indices)
435439

436440
if do_penalties:
437441
for seq_group in sampling_metadata.seq_groups:
@@ -549,7 +553,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
549553
device="cpu",
550554
dtype=torch.long,
551555
pin_memory=pin_memory,
552-
).T.contiguous()
556+
).t().contiguous()
553557

554558
# Because the memory is pinned, we can do non-blocking
555559
# transfer to device.

0 commit comments

Comments
 (0)