Commit 9528a74

fix for small cache-max-entry-count (InternLM#3221)

* fix for small cache-max-entry-count
* fix kernel

1 parent f87b1ed commit 9528a74
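
For context (hedged: this is based on lmdeploy's documented engine options, not on anything stated in this commit): cache_max_entry_count sets the fraction of free GPU memory reserved for the KV cache, and a very small value leaves the scheduler with few cache blocks, which is the regime this fix targets. A minimal sketch of setting it low with the PyTorch backend (the model name is illustrative):

    from lmdeploy import pipeline, PytorchEngineConfig

    # Deliberately tiny KV-cache budget (fraction of free GPU memory).
    # With this few cache blocks the scheduler can come up empty-handed,
    # the case this commit makes the prefetch loop tolerate.
    backend_config = PytorchEngineConfig(cache_max_entry_count=0.05)
    pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
    print(pipe('Hello'))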

2 files changed: +14, -6 lines
lmdeploy/pytorch/engine/engine.py (12 additions, 4 deletions)
@@ -554,7 +554,7 @@ def _make_infer_outputs(self, next_token_ids: torch.LongTensor, running: SeqList
             outputs[session_id].logits = logits.split(seq_length)[idx]
         return outputs
 
-    def _make_forward_inputs(self, prefill: bool = None):
+    def _make_forward_inputs(self, prefill: bool = None, enable_empty: bool = False):
         """make forward inputs."""
         prefill_interval = self.scheduler_config.prefill_interval
 
@@ -609,6 +609,10 @@ def __need_logits(seqs: SeqList):
         if prefill is None:
             prefill = self._do_prefill()
         scheduler_output = self.scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval)
+
+        if enable_empty and len(scheduler_output.running) == 0:
+            return None
+
         # schedule decoding if no valid prefill reqs.
         if prefill and len(scheduler_output.running) == 0:
             prefill = False
@@ -709,9 +713,13 @@ async def _async_loop_main(self, resp_que: asyncio.Queue, has_runable_event: asy
         forward_inputs = None
         next_running = None
 
-        async def _send_next_inputs(prefill: bool = None):
+        async def _send_next_inputs(prefill: bool = None, enable_empty: bool = False):
             nonlocal forward_inputs, next_running
-            forward_inputs = self._make_forward_inputs(prefill)
+            forward_inputs = self._make_forward_inputs(prefill, enable_empty)
+            if forward_inputs is None:
+                forward_inputs = None
+                next_running = None
+                return
             next_running = forward_inputs.pop('running')
             await self.executor.forward_async(forward_inputs)
 
@@ -730,7 +738,7 @@ async def _prefetch_next_inputs():
 
             if enable:
                 # send next forward
-                await _send_next_inputs(prefill)
+                await _send_next_inputs(prefill, True)
 
         while True:
             if next_running is None:
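
A reading of the engine change (my gloss, not from the commit message): with a very small cache-max-entry-count the scheduler can run out of free cache blocks, so schedule() may return an empty running list even when requests are waiting. The prefetch path previously assumed _make_forward_inputs always produced a batch; the new enable_empty flag lets the opportunistic prefetch caller receive None and skip the speculative forward, while mandatory callers keep the old fallback of retrying in decoding mode. A minimal, self-contained sketch of the same control flow — make_inputs and send_next are illustrative names, not lmdeploy APIs:

    from typing import Optional

    def make_inputs(running: list, enable_empty: bool = False) -> Optional[dict]:
        # With enable_empty=True the caller opts in to a None result when
        # the scheduler produced an empty batch, instead of forcing a
        # fallback re-schedule.
        if enable_empty and len(running) == 0:
            return None
        return {'running': running}

    def send_next(running: list, prefetch: bool = False) -> Optional[list]:
        # Prefetch calls are best-effort: an empty schedule cancels the
        # speculative forward instead of crashing on inputs.pop('running').
        inputs = make_inputs(running, enable_empty=prefetch)
        if inputs is None:
            return None
        return inputs.pop('running')

    assert send_next([], prefetch=True) is None        # skipped, nothing to run
    assert send_next(['seq0'], prefetch=True) == ['seq0']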

lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py (2 additions, 2 deletions)
@@ -116,15 +116,15 @@ def fused_moe_blocked_f8_kernel(
     as_ptrs = A_scale + offs_am * stride_asm
     bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn
 
-    acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
+    acc_scale = tl.load(as_ptrs, mask=mask_sid, other=1.0) * tl.load(bs_ptrs)
     acc_ratio = 1 / acc_scale
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # load scales
         k_start = (k + 1) * BLOCK_SIZE_K
         offs_ksa = k_start // group_ak
         offs_ksb = k_start // group_bk
-        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=k_start < K, other=1.0)
+        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid and k_start < K, other=1.0)
         b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)
 
         # load ab