Commit fc3dc82

Change condition
Signed-off-by: quic-sanising <[email protected]>
1 parent 83d33ac commit fc3dc82

1 file changed (+1, -1)

QEfficient/transformers/sampler/sampler.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ def sampler_forward(

     logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D

-    if input_ids.shape[1] != spec_length: # Prefill phase, initialize retained states
+    if input_ids.shape[1] > spec_length: # Prefill phase, initialize retained states
         repetition_penalty_retain_state_selected = torch.mul(repetition_penalty_retain_state_selected, 0)
         presence_penalty_retain_state_selected = torch.mul(presence_penalty_retain_state_selected, 0)
         # TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
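
The commit message does not say why the condition changed, so the following is only a sketch of the observable difference, not the author's rationale. It uses plain integers in place of `input_ids.shape[1]`, and the helper names (`resets_before`, `resets_after`, `compare`) and the sample values are hypothetical. The two predicates disagree only when the input is shorter than `spec_length`: the old check would reset the retained states there, the new one would not.

```python
def resets_before(seq_len, spec_length):
    """Old condition (hypothetical helper): reset when the lengths differ."""
    return seq_len != spec_length


def resets_after(seq_len, spec_length):
    """New condition (hypothetical helper): reset only when the input is longer."""
    return seq_len > spec_length


def compare(spec_length, seq_lens):
    """Return (seq_len, old_result, new_result) for each candidate length."""
    return [
        (n, resets_before(n, spec_length), resets_after(n, spec_length))
        for n in seq_lens
    ]


if __name__ == "__main__":
    # Illustrative values only: spec_length=4, inputs of length 1, 4 and 32.
    for n, old, new in compare(spec_length=4, seq_lens=[1, 4, 32]):
        print(f"seq_len={n:>2}  old(!=)={old}  new(>)={new}")
```

Inputs equal to or longer than `spec_length` behave the same under both conditions; only the shorter-input case changes.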
