Commit abbaf53

Change size of each sampling parameter to (batch_size, 1)
Signed-off-by: quic-sanising <[email protected]>
1 parent: fc3dc82

2 files changed: +18 -18 lines

QEfficient/transformers/models/modeling_auto.py (+7 -7)
@@ -258,7 +258,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
 }
 output_names.append("repetition_penalty_retain_state_RetainedState")

-example_inputs["repetition_penalties"] = torch.ones(bs, dtype=torch.float) * 0.5
+example_inputs["repetition_penalties"] = torch.ones((bs, 1), dtype=torch.float) * 0.5
 dynamic_axes["repetition_penalties"] = {0: "batch_size"}

 example_inputs["presence_penalty_retain_state"] = torch.zeros(
@@ -268,22 +268,22 @@ def export(self, export_dir: Optional[str] = None) -> str:
 }
 output_names.append("presence_penalty_retain_state_RetainedState")

-example_inputs["presence_penalties"] = torch.zeros(bs, dtype=torch.float) + 0.5
+example_inputs["presence_penalties"] = torch.zeros((bs, 1), dtype=torch.float) + 0.5
 dynamic_axes["presence_penalties"] = {0: "batch_size"}

-example_inputs["temperatures"] = torch.ones(bs, dtype=torch.float)
+example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float)
 dynamic_axes["temperatures"] = {0: "batch_size"}

-example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs,)).to(torch.int32)
+example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
 dynamic_axes["top_ks"] = {0: "batch_size"}

-example_inputs["top_ps"] = torch.ones(bs, dtype=torch.float) * 0.80
+example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80
 dynamic_axes["top_ps"] = {0: "batch_size"}

-example_inputs["min_ps"] = torch.ones(bs, dtype=torch.float) * 0.99
+example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.99
 dynamic_axes["min_ps"] = {0: "batch_size"}

-example_inputs["random_numbers"] = torch.rand(bs, dtype=torch.float)
+example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
 dynamic_axes["random_numbers"] = {0: "batch_size"}

 return self._export(
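For reference, a minimal sketch (not part of this commit; the bs and max_top_k_ids values are assumed) of the exported sampling inputs after this change. Every parameter tensor is now (batch_size, 1), and only axis 0 is marked dynamic:

import torch

bs, max_top_k_ids = 2, 512  # assumed values for illustration

example_inputs = {
    "repetition_penalties": torch.ones((bs, 1), dtype=torch.float) * 0.5,
    "presence_penalties": torch.zeros((bs, 1), dtype=torch.float) + 0.5,
    "temperatures": torch.ones((bs, 1), dtype=torch.float),
    "top_ks": torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32),
    "top_ps": torch.ones((bs, 1), dtype=torch.float) * 0.80,
    "min_ps": torch.ones((bs, 1), dtype=torch.float) * 0.99,
    "random_numbers": torch.rand((bs, 1), dtype=torch.float),
}

# Axis 0 remains the dynamic batch axis; axis 1 is a fixed singleton.
dynamic_axes = {name: {0: "batch_size"} for name in example_inputs}

for name, tensor in example_inputs.items():
    assert tensor.shape == (bs, 1), (name, tensor.shape)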

QEfficient/transformers/sampler/sampler.py (+11 -11)
@@ -235,7 +235,7 @@ def sampler_forward(

 # Repetition Penalty
 if (repetition_penalties != 1.).any():
-repetition_penalties = repetition_penalties.unsqueeze(1).repeat(spec_length, vocab_size) # (batch_size,) -> (batch_size * spec_length, vocab_size)
+repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size) # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
 repetition_penalty_retain_state_selected = repetition_penalty_retain_state_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
 repetition_penalties[~repetition_penalty_retain_state_selected.bool()] = 1.0
 logits = torch.where(
@@ -244,40 +244,40 @@ def sampler_forward(

 # Presence Penalty
 if (presence_penalties != 0.).any():
-presence_penalties = presence_penalties.unsqueeze(1).repeat(spec_length, 1) # (batch_size,) -> (batch_size * spec_length, 1)
+presence_penalties = presence_penalties.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1)
 presence_penalty_retain_state_selected = presence_penalty_retain_state_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
 logits -= presence_penalties * presence_penalty_retain_state_selected

 # TODO: Frequency Penalty

 # Temperature Scaling
 if (temperatures != 0).any():
-temperatures = temperatures.unsqueeze(1).repeat(spec_length, 1) # (batch_size,) -> (batch_size * spec_length, 1)
+temperatures = temperatures.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1)
 logits = torch.where(temperatures != 0, logits / temperatures, logits)

 # Top K
 # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any(): skip
-topk_values_asc, topk_indices_asc = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1, largest=False) # (batch_size * spec_length, vocab_size)
+topk_values_asc, topk_indices_asc = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1, largest=False) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
 top_ks[top_ks > Constants.MAX_TOP_K_IDS] = Constants.MAX_TOP_K_IDS # Clip k to max value
 # True values in this mask indicate the positions of the non-top K values
-topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).unsqueeze(1).repeat(spec_length, 1)
+topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).repeat(spec_length, 1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
 topk_values_asc[topk_mask] = torch.finfo(torch.float16).min

 # Top P
 # TODO (Optimization): if (top_ps != 1.).any(): skip but will need top_probs for Min P
-top_probs = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, vocab_size)
+top_probs = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
 topk_probs_sum = torch.cumsum(top_probs, dim=1)
-top_p_mask = topk_probs_sum <= 1 - top_ps.unsqueeze(1).repeat(spec_length, 1)
+top_p_mask = topk_probs_sum <= 1 - top_ps.repeat(spec_length, 1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
 top_p_mask[:, Constants.MAX_TOP_K_IDS - 1] = False
 topk_values_asc[top_p_mask] = torch.finfo(torch.float16).min

 # Min P
 # TODO (Optimization): if (min_ps != 0.).any(): skip
-scaled_min_p = torch.mul(min_ps.repeat(spec_length), top_probs[:, -1]) # (batch_size * spec_length,)
-min_p_mask = top_probs < scaled_min_p.unsqueeze(1)
+scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, -1:]) # (batch_size * spec_length, 1)
+min_p_mask = top_probs < scaled_min_p # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
 topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min

-logits = logits.scatter(1, topk_indices_asc, topk_values_asc)
+logits = logits.scatter(1, topk_indices_asc, topk_values_asc) # (batch_size * spec_length, vocab_size)

 # Softmax
 # TODO (Optimization): if (temperatures == 0).all(): skip and perform only greedy sampling
@@ -286,7 +286,7 @@ def sampler_forward(
 # Sample the next tokens
 # TODO (Optimization): if self.return_pds: skip
 greedy_samples = torch.argmax(probs, dim=-1, keepdim=True) # Greedy Sampling
-gumbel_noise = -torch.log(-torch.log(random_numbers.unsqueeze(1).repeat(spec_length, 1))) # Gumbel-Max Trick
+gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
 y = probs + gumbel_noise
 random_samples = torch.argmax(y, dim=-1, keepdim=True) # Random Sampling
 next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples) # (batch_size * spec_length, 1)
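For context, a minimal sketch (not from the repository; sizes are assumed) of why the unsqueeze(1) calls become redundant once the sampling parameters arrive as (batch_size, 1): repeating the column directly produces the same (batch_size * spec_length, 1) tensor, which then broadcasts against the flattened logits.

import torch

batch_size, spec_length, vocab_size = 2, 3, 8  # assumed sizes for illustration

old_temperatures = torch.tensor([0.7, 1.3])       # (batch_size,), the former layout
new_temperatures = old_temperatures.unsqueeze(1)  # (batch_size, 1), the new layout

# Old path: unsqueeze(1) then repeat; new path: repeat the column as-is.
old_style = old_temperatures.unsqueeze(1).repeat(spec_length, 1)  # (batch_size * spec_length, 1)
new_style = new_temperatures.repeat(spec_length, 1)               # (batch_size * spec_length, 1)
assert torch.equal(old_style, new_style)

# The (batch_size * spec_length, 1) column broadcasts against the flattened
# (batch_size * spec_length, vocab_size) logits, e.g. for temperature scaling.
logits = torch.randn(batch_size * spec_length, vocab_size)
scaled = torch.where(new_style != 0, logits / new_style, logits)
assert scaled.shape == (batch_size * spec_length, vocab_size)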
