@@ -535,7 +535,6 @@ def decode_n_tokens(
535
535
attention_backend : SDPBackend = torch .nn .attention .SDPBackend .MATH ,
536
536
** sampling_kwargs ,
537
537
):
538
- new_tokens , new_probs = [], []
539
538
encountered_eos = False
540
539
for _i in range (
541
540
num_new_tokens - 1
@@ -553,12 +552,10 @@ def decode_n_tokens(
553
552
** sampling_kwargs ,
554
553
)
555
554
input_pos += 1
556
- new_tokens .append (next_token .clone ())
557
- callback (new_tokens [- 1 ], done_generating = _i == num_new_tokens - 2 )
558
- if need_probs or next_prob is None :
555
+ callback (next_token .clone (), done_generating = _i == num_new_tokens - 2 )
556
+ if not need_probs or next_prob is None :
559
557
yield out_token , None
560
558
else :
561
- new_probs .append (next_prob .clone ())
562
559
yield out_token , next_prob .clone ()
563
560
cur_token = next_token
564
561
@@ -585,7 +582,6 @@ def decode_n_tokens(
585
582
dtype = cur_token .dtype ,
586
583
device = cur_token .device ,
587
584
)
588
- new_tokens .append (eos_token .clone ())
589
585
eos_token , next_prob = self .decode_one_token (
590
586
model ,
591
587
eos_token .view (1 , - 1 ),
@@ -788,7 +784,6 @@ def generate(
788
784
input_pos = input_pos + num_added
789
785
next_token = next_tokens [- 1 ]
790
786
else :
791
- generated_tokens = []
792
787
for generated_token , _ in self .decode_n_tokens (
793
788
model ,
794
789
next_token ,
@@ -806,7 +801,6 @@ def generate(
806
801
attention_backend = attention_backend ,
807
802
** sampling_kwargs ,
808
803
):
809
- generated_tokens .append (generated_token .view (- 1 ))
810
804
yield generated_token , None
811
805
812
806
generate_stats = {
0 commit comments