@@ -13,6 +13,8 @@
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
+# Some triton sampler related code is guarded before it is ready.
+_USE_TRITON_SAMPLER = False
 
 
 @dataclass
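The new `_USE_TRITON_SAMPLER` flag is a plain module-level constant, so enabling the guarded code currently means editing the source. As a minimal sketch (not part of this diff), such a guard could later be wired to an environment variable; `VLLM_USE_TRITON_SAMPLER` below is a hypothetical name, not an existing vLLM setting:

import os

# Hypothetical wiring: flip the guard via an environment variable instead of
# a hard-coded False. VLLM_USE_TRITON_SAMPLER is an assumed name.
_USE_TRITON_SAMPLER = os.environ.get("VLLM_USE_TRITON_SAMPLER", "0") == "1"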
@@ -347,14 +349,16 @@ def from_sampling_metadata(
         repetition_penalties: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
-        prompt_best_of: List[int] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
 
-        # We need one base seed per Triton slice.
-        seeds_to_generate = (extra_seeds_to_generate +
-                             get_num_triton_sampler_splits(vocab_size))
+        if _USE_TRITON_SAMPLER:
+            prompt_best_of: List[int] = []
+
+            # We need one base seed per Triton slice.
+            seeds_to_generate = (extra_seeds_to_generate +
+                                 get_num_triton_sampler_splits(vocab_size))
 
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
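The guarded block reserves one extra base seed per Triton slice. `get_num_triton_sampler_splits` itself is not shown in this diff; a plausible sketch, assuming the sampler splits the vocab into fixed-width column chunks that one kernel launch can handle (the constant below is illustrative, not taken from the source):

import math

# Assumed per-launch column limit for the Triton sampling kernel
# (illustrative value only).
_MAX_TRITON_N_COLS = 131072

def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Number of vocab slices, and hence base seeds, the sampler needs."""
    return math.ceil(n_cols / _MAX_TRITON_N_COLS)

# Example: a 152064-token vocab would need 2 slices, so 2 base seeds.
assert get_num_triton_sampler_splits(152064) == 2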
@@ -366,9 +370,6 @@ def from_sampling_metadata(
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
             min_p = sampling_params.min_p
-            seed = sampling_params.seed
-
-            is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
 
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
@@ -389,8 +390,7 @@ def from_sampling_metadata(
                 do_penalties = True
 
             is_prompt = seq_group.is_prompt
-            if (seq_group.is_prompt
-                    and sampling_params.prompt_logprobs is not None):
+            if (is_prompt and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
                 query_len = seq_group.query_len
@@ -415,23 +415,27 @@ def from_sampling_metadata(
                 frequency_penalties += [f] * len(seq_ids)
                 repetition_penalties += [r] * len(seq_ids)
 
-            if is_prompt:
-                prompt_best_of.append(sampling_params.best_of)
-                query_len = seq_group.query_len
-                assert query_len is not None
-
-            for seq_id in seq_ids:
-                seq_data = seq_group.seq_data[seq_id]
-                extra_entropy = extra_entropy or ()
-                seq_seeds = cls._get_sequence_seeds(
-                    seed,
-                    seq_data.get_len(),
-                    *extra_entropy,
-                    seq_id,
-                    seeds_to_generate=seeds_to_generate,
-                    is_greedy=is_greedy)
-                sampling_seeds.append(seq_seeds)
-            sample_indices.extend(seq_group.sample_indices)
+            if _USE_TRITON_SAMPLER:
+                if is_prompt:
+                    prompt_best_of.append(sampling_params.best_of)
+                    query_len = seq_group.query_len
+                    assert query_len is not None
+
+                seed = sampling_params.seed
+                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+
+                for seq_id in seq_ids:
+                    seq_data = seq_group.seq_data[seq_id]
+                    extra_entropy = extra_entropy or ()
+                    seq_seeds = cls._get_sequence_seeds(
+                        seed,
+                        seq_data.get_len(),
+                        *extra_entropy,
+                        seq_id,
+                        seeds_to_generate=seeds_to_generate,
+                        is_greedy=is_greedy)
+                    sampling_seeds.append(seq_seeds)
+                sample_indices.extend(seq_group.sample_indices)
 
         if do_penalties:
             for seq_group in sampling_metadata.seq_groups:
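`_get_sequence_seeds` is outside this hunk. A sketch of the idea, assuming it derives `seeds_to_generate` deterministic child seeds from the request seed plus per-sequence entropy (sequence length, seq_id, any extra entropy), with seed 0 reserved as the greedy sentinel; details may differ from the real helper:

import random

import torch

_SEED_0_REPLACEMENT = 3403598558  # same constant as at the top of this file


def _get_sequence_seeds_sketch(seed, *extra_entropy: int,
                               seeds_to_generate: int, is_greedy: bool):
    """Illustrative per-sequence seed derivation (not the vLLM implementation)."""
    if is_greedy:
        # Seed 0 tells the kernel to decode greedily, so no RNG is needed.
        return [0] * seeds_to_generate
    # Deterministic when the request supplied a seed, random otherwise.
    rng = random.Random(str((seed, ) + extra_entropy)) if seed is not None else random
    lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
    # A drawn 0 would collide with the greedy sentinel, so swap it for the
    # fixed replacement constant.
    return [rng.randint(lo, hi) or _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)]


# Example: three slices, seeded request -> three reproducible non-zero seeds.
print(_get_sequence_seeds_sketch(42, 17, 5, seeds_to_generate=3, is_greedy=False))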
@@ -549,7 +553,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             device="cpu",
             dtype=torch.long,
             pin_memory=pin_memory,
-        ).T.contiguous()
+        ).t().contiguous()
 
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
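On a 2-D tensor, `.t()` and `.T` produce the same transpose; `.t()` simply makes the two-dimensional intent explicit, while `.T` reverses all dimensions and is deprecated for non-2-D tensors in recent PyTorch. A small standalone illustration of the pin-then-transpose-then-contiguous pattern used in `from_lists`, with made-up values:

import torch

pin_memory = torch.cuda.is_available()  # pinning requires a CUDA-capable build

# Build on CPU (pinned if possible), transpose the 2-D tensor with .t(), and
# make it contiguous so a later .to(device, non_blocking=True) copy can
# overlap with other work.
seeds_t = torch.tensor(
    [[1, 2, 3], [4, 5, 6]],  # illustrative values only
    device="cpu",
    dtype=torch.long,
    pin_memory=pin_memory,
).t().contiguous()

assert seeds_t.shape == (3, 2) and seeds_t.is_contiguous()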