Skip to content

Commit 09c82e4

Browse files
sleepcoo and yinfan98
authored and committed
fix FlashMLA cudagraph config (sgl-project#4691)
Co-authored-by: yinfan98 <[email protected]>
1 parent 2b969a3 commit 09c82e4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/sglang/srt/layers/attention/flashmla_backend.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
92 92
if forward_batch.forward_mode.is_decode_or_idle():
93 93
if spec_info is None:
94 94
max_seqlen_pad = triton.cdiv(
95 -
forward_batch.seq_lens.max().item(), PAGE_SIZE
95 +
forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
96 96
)
97 97
block_kv_indices = torch.full(
98 98
(bs, max_seqlen_pad),
@@ -206,8 +206,10 @@ def init_forward_metadata_replay_cuda_graph(
206 206
):
207 207

208 208
if forward_mode.is_decode_or_idle():
209 +
assert seq_lens_cpu is not None
209 210
seq_lens = seq_lens[:bs]
210 -
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
211 +
seq_lens_cpu = seq_lens_cpu[:bs]
212 +
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
211 213
create_flashmla_kv_indices_triton[(bs,)](
212 214
self.req_to_token,
213 215
req_pool_indices[:bs],

0 commit comments

Comments
 (0)