We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 83d33ac commit fc3dc82Copy full SHA for fc3dc82
QEfficient/transformers/sampler/sampler.py
@@ -218,7 +218,7 @@ def sampler_forward(
218
219
logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D
220
221
- if input_ids.shape[1] != spec_length: # Prefill phase, initialize retained states
+ if input_ids.shape[1] > spec_length: # Prefill phase, initialize retained states
222
repetition_penalty_retain_state_selected = torch.mul(repetition_penalty_retain_state_selected, 0)
223
presence_penalty_retain_state_selected = torch.mul(presence_penalty_retain_state_selected, 0)
224
# TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
0 commit comments