@@ -554,7 +554,7 @@ def _make_infer_outputs(self, next_token_ids: torch.LongTensor, running: SeqList
554
554
outputs [session_id ].logits = logits .split (seq_length )[idx ]
555
555
return outputs
556
556
557
- def _make_forward_inputs (self , prefill : bool = None ):
557
+ def _make_forward_inputs (self , prefill : bool = None , enable_empty : bool = False ):
558
558
"""make forward inputs."""
559
559
prefill_interval = self .scheduler_config .prefill_interval
560
560
@@ -609,6 +609,10 @@ def __need_logits(seqs: SeqList):
609
609
if prefill is None :
610
610
prefill = self ._do_prefill ()
611
611
scheduler_output = self .scheduler .schedule (is_prefill = prefill , prealloc_size = prefill_interval )
612
+
613
+ if enable_empty and len (scheduler_output .running ) == 0 :
614
+ return None
615
+
612
616
# schedule decoding if no valid prefill reqs.
613
617
if prefill and len (scheduler_output .running ) == 0 :
614
618
prefill = False
@@ -709,9 +713,13 @@ async def _async_loop_main(self, resp_que: asyncio.Queue, has_runable_event: asy
709
713
forward_inputs = None
710
714
next_running = None
711
715
712
- async def _send_next_inputs (prefill : bool = None ):
716
+ async def _send_next_inputs (prefill : bool = None , enable_empty : bool = False ):
713
717
nonlocal forward_inputs , next_running
714
- forward_inputs = self ._make_forward_inputs (prefill )
718
+ forward_inputs = self ._make_forward_inputs (prefill , enable_empty )
719
+ if forward_inputs is None :
720
+ forward_inputs = None
721
+ next_running = None
722
+ return
715
723
next_running = forward_inputs .pop ('running' )
716
724
await self .executor .forward_async (forward_inputs )
717
725
@@ -730,7 +738,7 @@ async def _prefetch_next_inputs():
730
738
731
739
if enable :
732
740
# send next forward
733
- await _send_next_inputs (prefill )
741
+ await _send_next_inputs (prefill , True )
734
742
735
743
while True :
736
744
if next_running is None :
0 commit comments