Skip to content

Commit 8a0897d

Browse files
decode_n_tokens clean up (#1532)
1 parent 701d826 commit 8a0897d

File tree

1 file changed: +2 additions, -8 deletions

1 file changed: +2 additions, -8 deletions

torchchat/generate.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,6 @@ def decode_n_tokens(
535535
attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
536536
**sampling_kwargs,
537537
):
538-
new_tokens, new_probs = [], []
539538
encountered_eos = False
540539
for _i in range(
541540
num_new_tokens - 1
@@ -553,12 +552,10 @@ def decode_n_tokens(
553552
**sampling_kwargs,
554553
)
555554
input_pos += 1
556-
new_tokens.append(next_token.clone())
557-
callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
558-
if need_probs or next_prob is None:
555+
callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
556+
if not need_probs or next_prob is None:
559557
yield out_token, None
560558
else:
561-
new_probs.append(next_prob.clone())
562559
yield out_token, next_prob.clone()
563560
cur_token = next_token
564561

@@ -585,7 +582,6 @@ def decode_n_tokens(
585582
dtype=cur_token.dtype,
586583
device=cur_token.device,
587584
)
588-
new_tokens.append(eos_token.clone())
589585
eos_token, next_prob = self.decode_one_token(
590586
model,
591587
eos_token.view(1, -1),
@@ -788,7 +784,6 @@ def generate(
788784
input_pos = input_pos + num_added
789785
next_token = next_tokens[-1]
790786
else:
791-
generated_tokens = []
792787
for generated_token, _ in self.decode_n_tokens(
793788
model,
794789
next_token,
@@ -806,7 +801,6 @@ def generate(
806801
attention_backend=attention_backend,
807802
**sampling_kwargs,
808803
):
809-
generated_tokens.append(generated_token.view(-1))
810804
yield generated_token, None
811805

812806
generate_stats = {

0 commit comments

Comments (0)