@@ -235,7 +235,7 @@ def sampler_forward(

     # Repetition Penalty
     if (repetition_penalties != 1.).any():
-        repetition_penalties = repetition_penalties.unsqueeze(1).repeat(spec_length, vocab_size)  # (batch_size,) -> (batch_size * spec_length, vocab_size)
+        repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size)  # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
         repetition_penalty_retain_state_selected = repetition_penalty_retain_state_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         repetition_penalties[~repetition_penalty_retain_state_selected.bool()] = 1.0
         logits = torch.where(
@@ -244,40 +244,40 @@ def sampler_forward(

     # Presence Penalty
     if (presence_penalties != 0.).any():
-        presence_penalties = presence_penalties.unsqueeze(1).repeat(spec_length, 1)  # (batch_size,) -> (batch_size * spec_length, 1)
+        presence_penalties = presence_penalties.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
         presence_penalty_retain_state_selected = presence_penalty_retain_state_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         logits -= presence_penalties * presence_penalty_retain_state_selected

     # TODO: Frequency Penalty

     # Temperature Scaling
     if (temperatures != 0).any():
-        temperatures = temperatures.unsqueeze(1).repeat(spec_length, 1)  # (batch_size,) -> (batch_size * spec_length, 1)
+        temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
         logits = torch.where(temperatures != 0, logits / temperatures, logits)

     # Top K
     # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any(): skip
-    topk_values_asc, topk_indices_asc = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1, largest=False)  # (batch_size * spec_length, vocab_size)
+    topk_values_asc, topk_indices_asc = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1, largest=False)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     top_ks[top_ks > Constants.MAX_TOP_K_IDS] = Constants.MAX_TOP_K_IDS  # Clip k to max value
     # True values in this mask indicate the positions of the non-top K values
-    topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).unsqueeze(1).repeat(spec_length, 1)
+    topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).repeat(spec_length, 1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     topk_values_asc[topk_mask] = torch.finfo(torch.float16).min

     # Top P
     # TODO (Optimization): if (top_ps != 1.).any(): skip but will need top_probs for Min P
-    top_probs = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, vocab_size)
+    top_probs = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     topk_probs_sum = torch.cumsum(top_probs, dim=1)
-    top_p_mask = topk_probs_sum <= 1 - top_ps.unsqueeze(1).repeat(spec_length, 1)
+    top_p_mask = topk_probs_sum <= 1 - top_ps.repeat(spec_length, 1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     top_p_mask[:, Constants.MAX_TOP_K_IDS - 1] = False
     topk_values_asc[top_p_mask] = torch.finfo(torch.float16).min

     # Min P
     # TODO (Optimization): if (min_ps != 0.).any(): skip
-    scaled_min_p = torch.mul(min_ps.repeat(spec_length), top_probs[:, -1])  # (batch_size * spec_length,)
-    min_p_mask = top_probs < scaled_min_p.unsqueeze(1)
+    scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, -1:])  # (batch_size * spec_length, 1)
+    min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min

-    logits = logits.scatter(1, topk_indices_asc, topk_values_asc)
+    logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)

     # Softmax
     # TODO (Optimization): if (temperatures == 0).all(): skip and perform only greedy sampling
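
The recurring change in this diff is that the per-request sampling parameters (`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ps`, `min_ps`, `random_numbers`) now arrive with shape `(batch_size, 1)` instead of `(batch_size,)`, so the extra `.unsqueeze(1)` before `.repeat(...)` is no longer needed. A minimal sketch of that equivalence, using illustrative values for `batch_size`, `spec_length`, and `vocab_size` (not taken from the repository):

```python
import torch

batch_size, spec_length, vocab_size = 2, 3, 4
penalties_1d = torch.tensor([1.2, 1.5])   # old input layout: (batch_size,)
penalties_2d = penalties_1d.unsqueeze(1)  # new input layout: (batch_size, 1)

# Old code path: unsqueeze to 2-D, then tile across spec_length and vocab_size.
old = penalties_1d.unsqueeze(1).repeat(spec_length, vocab_size)  # (batch_size * spec_length, vocab_size)
# New code path: the input is already 2-D, so a single repeat suffices.
new = penalties_2d.repeat(spec_length, vocab_size)               # (batch_size * spec_length, vocab_size)

assert old.shape == new.shape == (batch_size * spec_length, vocab_size)
assert torch.equal(old, new)
```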
@@ -286,7 +286,7 @@ def sampler_forward(
     # Sample the next tokens
     # TODO (Optimization): if self.return_pds: skip
     greedy_samples = torch.argmax(probs, dim=-1, keepdim=True)  # Greedy Sampling
-    gumbel_noise = -torch.log(-torch.log(random_numbers.unsqueeze(1).repeat(spec_length, 1)))  # Gumbel-Max Trick
+    gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
     y = probs + gumbel_noise
     random_samples = torch.argmax(y, dim=-1, keepdim=True)  # Random Sampling
     next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples)  # (batch_size * spec_length, 1)
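
For reference, the "Gumbel-Max Trick" named in the comment is the identity that `argmax(log p + g)`, with `g = -log(-log(u))` and `u ~ Uniform(0, 1)`, draws an index from the categorical distribution `p`. Below is a self-contained sketch of that standard form; the kernel above feeds pre-generated `random_numbers` and adds the noise to `probs` directly, so this is only an illustration of the underlying idea, not the exact on-device computation:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])  # toy categorical distribution
n = 100_000

u = torch.rand(n, probs.numel())       # uniform random numbers in [0, 1)
gumbel = -torch.log(-torch.log(u))     # Gumbel(0, 1) noise
samples = torch.argmax(torch.log(probs) + gumbel, dim=-1)

# Empirical frequencies should be close to probs (~[0.1, 0.2, 0.7]).
print(torch.bincount(samples, minlength=probs.numel()) / n)
```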