
Commit ced8c71

deepseek-MTP eagle, topk=1 (#1073)

Authored by shihaobai, hiworldwzj, and wangzaijun
Co-authored-by: hiworldwzj <[email protected]>
Co-authored-by: wangzaijun <[email protected]>

1 parent b23bf4c · commit ced8c71

27 files changed: +1045 −538 lines

docs/CN/source/tutorial/api_server_args_zh.rst

Lines changed: 3 additions & 2 deletions
@@ -445,9 +445,10 @@ MTP multi-prediction parameters
 
 .. option:: --mtp_mode
 
-   Supported mtp modes, optional values:
+   Supported mtp modes; deepseekv3_eagle is recommended for a better performance experience. Optional values:
 
-   * ``deepseekv3``
+   * ``deepseekv3_vanilla``
+   * ``deepseekv3_eagle``
    * ``None``: do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir

docs/EN/source/tutorial/api_server_args_zh.rst

Lines changed: 3 additions & 2 deletions
@@ -442,9 +442,10 @@ MTP Multi-Prediction Parameters
 
 .. option:: --mtp_mode
 
-   Supported mtp modes, optional values:
+   Supported mtp modes; it is recommended to use deepseekv3_eagle for better performance. Optional values:
 
-   * ``deepseekv3``
+   * ``deepseekv3_vanilla``
+   * ``deepseekv3_eagle``
    * ``None``: Do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir
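
For orientation, a hypothetical launch line combining the renamed mode with its companion flags; the entrypoint and model paths are assumptions for illustration, only --mtp_mode, --mtp_step and --mtp_draft_model_dir come from this change set:

    python -m lightllm.server.api_server \
        --model_dir /path/to/DeepSeek-V3 \
        --mtp_mode deepseekv3_eagle \
        --mtp_step 1 \
        --mtp_draft_model_dir /path/to/mtp_draft_model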

lightllm/common/basemodel/basemodel.py

Lines changed: 50 additions & 24 deletions
@@ -67,12 +67,16 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.data_type = kvargs.get("data_type", "float16")
+        mtp_step = get_env_start_args().mtp_step
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (
             self.graph_max_batch_size // 2
             if get_env_start_args().enable_decode_microbatch_overlap
             else self.graph_max_batch_size
         )
+        # In mtp mode the maximum batch size must be adjusted to a multiple of (mtp_step + 1).
+        self.graph_max_batch_size = self.graph_max_batch_size * (mtp_step + 1)
+
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
         self.quant_type = kvargs.get("quant_type", "none")
@@ -81,7 +85,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
-        self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
+        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
 
         self._init_datatype()
         self._init_config()
@@ -258,6 +262,10 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.batch_size = model_input.batch_size
         infer_state.total_token_num = model_input.total_token_num
         infer_state.max_len_in_batch = model_input.max_len_in_batch
+        infer_state.max_q_seq_len = model_input.max_q_seq_len
+        infer_state.max_kv_seq_len = model_input.max_kv_seq_len
+        infer_state.max_cache_len = model_input.max_cache_len
+        infer_state.prefix_total_token_num = model_input.prefix_total_token_num
         assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
         infer_state.b_req_idx = model_input.b_req_idx
         infer_state.b_seq_len = model_input.b_seq_len
@@ -335,16 +343,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state.b_req_idx,
+            b_seq_len=infer_state.b_seq_len,
+            b_ready_cache_len=infer_state.b_ready_cache_len,
+            b_start_loc=infer_state.b_start_loc,
+            alloc_mem_index=infer_state.mem_index,
+            max_q_seq_len=infer_state.max_q_seq_len,
         )
-
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
@@ -474,26 +482,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
+        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state0.b_req_idx,
+            b_seq_len=infer_state0.b_seq_len,
+            b_ready_cache_len=infer_state0.b_ready_cache_len,
+            b_start_loc=infer_state0.b_start_loc,
+            alloc_mem_index=infer_state0.mem_index,
+            max_q_seq_len=infer_state0.max_q_seq_len,
         )
-        infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
+        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state1.b_req_idx,
+            b_seq_len=infer_state1.b_seq_len,
+            b_ready_cache_len=infer_state1.b_ready_cache_len,
+            b_start_loc=infer_state1.b_start_loc,
+            alloc_mem_index=infer_state1.mem_index,
+            max_q_seq_len=infer_state1.max_q_seq_len,
        )
-        infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -521,7 +531,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_input1.b_req_idx,
             model_input1.b_mtp_index,
         )
-
+        # TODO: fix for dynamic mtp
         assert model_input0.batch_size == model_input1.batch_size
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda
@@ -531,6 +541,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
 
         if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size)
+            # TODO: if mtp with a dynamic step count is supported, model_input0 and model_input1 may have different
+            # internal batch sizes at different mtp steps; look up the graph with the larger batch size and restore the outputs accordingly.
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
@@ -568,6 +580,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 input_ids1=padded_model_input1.input_ids,
                 infer_state1=infer_state1,
             )
+
+            # TODO: fix for dynamic mtp
             model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size)
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
@@ -696,6 +710,10 @@ def _check_max_len_infer(self):
            batch_size=1,
            total_token_num=total_token_num,
            max_len_in_batch=self.batch_max_tokens,
+           max_q_seq_len=self.batch_max_tokens,
+           max_kv_seq_len=self.batch_max_tokens,
+           max_cache_len=0,
+           prefix_total_token_num=0,
            input_ids=dummy_input_ids,
            mem_indexes=mem_indexes,
            b_req_idx=b_req_idx,
@@ -766,6 +784,10 @@ def _autotune_warmup(self):
            batch_size=1,
            total_token_num=total_token_num,
            max_len_in_batch=input_len,
+           max_q_seq_len=input_len,
+           max_kv_seq_len=input_len,
+           max_cache_len=0,
+           prefix_total_token_num=0,
            input_ids=dummy_input_ids,
            mem_indexes=mem_indexes,
            b_req_idx=b_req_idx,
@@ -822,6 +844,10 @@ def _init_padded_req(self):
            batch_size=batch_size,
            total_token_num=total_token_num,
            max_len_in_batch=prefill_input_len,
+           max_q_seq_len=prefill_input_len,
+           max_kv_seq_len=prefill_input_len,
+           max_cache_len=0,
+           prefix_total_token_num=0,
            input_ids=dummy_input_ids,
            mem_indexes=mem_indexes,
            b_req_idx=b_req_idx,
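
To make the batch-size arithmetic in the __init__ hunk above concrete, here is a minimal standalone sketch; the concrete numbers (graph_max_batch_size=16, mtp_step=1) are illustrative assumptions, not values taken from this diff:

    def scaled_graph_max_batch_size(graph_max_batch_size: int,
                                    mtp_step: int,
                                    enable_decode_microbatch_overlap: bool) -> int:
        # Mirrors the adjustment shown above: halve for decode microbatch overlap,
        # then scale to a multiple of (mtp_step + 1), since in mtp mode every
        # request contributes (1 + mtp_step) decode positions.
        if enable_decode_microbatch_overlap:
            graph_max_batch_size //= 2
        return graph_max_batch_size * (mtp_step + 1)

    print(scaled_graph_max_batch_size(16, mtp_step=1, enable_decode_microbatch_overlap=False))  # 32
    print(scaled_graph_max_batch_size(16, mtp_step=1, enable_decode_microbatch_overlap=True))   # 16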

lightllm/common/basemodel/batch_objs.py

Lines changed: 11 additions & 4 deletions
@@ -10,10 +10,17 @@ class ModelInput:
     batch_size: int
     total_token_num: int
     max_len_in_batch: int
-    input_ids: torch.Tensor
-    b_req_idx: torch.Tensor
-    b_mtp_index: torch.Tensor
-    b_seq_len: torch.Tensor
+    # In the decode phase, max_q_seq_len is always 1 in normal mode.
+    # In mtp mode it is the maximum per-request length once the mtp steps are counted in,
+    # i.e. its actual value is max([(1 + req.mtp_step) for req in reqs]).
+    max_q_seq_len: int
+    max_kv_seq_len: int
+    max_cache_len: int = None
+    prefix_total_token_num: int = None
+    input_ids: torch.Tensor = None
+    b_req_idx: torch.Tensor = None
+    b_mtp_index: torch.Tensor = None
+    b_seq_len: torch.Tensor = None
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
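
To illustrate the comment on max_q_seq_len, a small sketch of how the decode-phase value could be computed; the Req stand-in and its mtp_step attribute are hypothetical, only the formula comes from the comment above:

    from dataclasses import dataclass

    @dataclass
    class Req:
        # Hypothetical stand-in for a scheduled request; mtp_step is the number
        # of extra draft tokens the request carries per decode step.
        mtp_step: int

    def decode_max_q_seq_len(reqs: list[Req]) -> int:
        # Normal decode: each request contributes exactly one query token, so the max is 1.
        # MTP decode: each request contributes (1 + mtp_step) query tokens.
        return max((1 + req.mtp_step) for req in reqs)

    print(decode_max_q_seq_len([Req(mtp_step=0), Req(mtp_step=0)]))  # normal mode -> 1
    print(decode_max_q_seq_len([Req(mtp_step=1), Req(mtp_step=1)]))  # mtp, one draft step -> 2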

lightllm/common/basemodel/cuda_graph.py

Lines changed: 11 additions & 7 deletions
@@ -19,27 +19,27 @@ class CudaGraph:
     def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         self.graph = {}
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+        self.args = get_env_start_args()
+        self.mtp_step = self.args.mtp_step
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
-        self.args = get_env_start_args()
         self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
 
         # gen cuda graph batch_sizes
         # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
         # and [graph_split_batch_size + graph_grow_step_size,
-        # graph_split_batch_size + 2 * graph_grow_step_size, ..., self.max_batch_size]
-        graph_split_batch_size = self.args.graph_split_batch_size
-        max_batch_size = self.max_batch_size
-        graph_grow_step_size = self.args.graph_grow_step_size
+        # if the mtp_step is not 0, the batch_sizes will be multiples of (mtp_step + 1)
 
-        batch_sizes = [i for i in range(1, graph_split_batch_size + 1)]
+        graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1)
+        graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1)
+
+        batch_sizes = [i * (self.mtp_step + 1) for i in range(1, graph_split_batch_size + 1)]
         for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
             batch_sizes.append(_batch_size)
 
         batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
         batch_sizes.append(max_batch_size)
         batch_sizes.sort()
-
         self.cuda_graph_batch_sizes = batch_sizes
         assert batch_sizes[-1] == self.max_batch_size
         logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")
@@ -208,6 +208,8 @@ def warmup(self, model):
                 batch_size=batch_size,
                 total_token_num=total_token_num,
                 max_len_in_batch=max_len_in_batch,
+                max_q_seq_len=self.mtp_step + 1,
+                max_kv_seq_len=max_len_in_batch,
                 input_ids=input_ids,
                 mem_indexes=mem_indexes,
                 b_req_idx=b_req_idx,
@@ -265,6 +267,8 @@ def warmup_overlap(self, model):
                 batch_size=batch_size,
                 total_token_num=total_token_num,
                 max_len_in_batch=max_len_in_batch,
+                max_q_seq_len=self.mtp_step + 1,
+                max_kv_seq_len=max_len_in_batch,
                 input_ids=input_ids,
                 b_mtp_index=b_mtp_index,
                 mem_indexes=mem_indexes,
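
For a sense of what the new batch-size schedule looks like, a standalone sketch that mirrors the constructor logic above; the concrete values (graph_split_batch_size=8, graph_grow_step_size=4, max_batch_size=48, mtp_step=1) are illustrative assumptions:

    def gen_cuda_graph_batch_sizes(graph_split_batch_size: int,
                                   graph_grow_step_size: int,
                                   max_batch_size: int,
                                   mtp_step: int) -> list[int]:
        # Mirror of CudaGraph.__init__: every captured batch size is a
        # multiple of (mtp_step + 1) when mtp is enabled.
        split = graph_split_batch_size * (mtp_step + 1)
        step = graph_grow_step_size * (mtp_step + 1)
        batch_sizes = [i * (mtp_step + 1) for i in range(1, split + 1)]
        for b in range(split + step, max_batch_size, step):
            batch_sizes.append(b)
        batch_sizes = sorted(set(b for b in batch_sizes if b < max_batch_size))
        batch_sizes.append(max_batch_size)
        return batch_sizes

    print(gen_cuda_graph_batch_sizes(8, 4, 48, mtp_step=1))
    # -> [2, 4, 6, ..., 30, 32, 40, 48]: multiples of 2 up to the split point, then grow steps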

lightllm/common/basemodel/infer_struct.py

Lines changed: 5 additions & 3 deletions
@@ -25,6 +25,11 @@ def __init__(self):
         # In the prefill phase: the maximum per-request input token length (excluding the already cached part).
         # In the decode phase: the maximum total length per request.
         self.max_len_in_batch: int = None
+        # max_cache_len marks, in the prefill phase, the maximum cached kv length among the requests.
+        self.max_cache_len: int = None
+        # prefix_total_token_num marks, in the prefill phase, the summed length of all kv that is
+        # already ready across the current requests; its value equals sum(b_ready_cache_len).
+        self.prefix_total_token_num: int = None
         self.is_prefill: bool = None
 
         self.mem_manager: MemoryManager = None
@@ -72,8 +77,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b_kv_seq_len,
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
-                self.max_q_seq_len,
-                self.max_kv_seq_len,
             ) = gen_prefill_params(
                 input_token_num=input_ids.shape[0],
                 b_ready_cache_len=self.b_ready_cache_len,
@@ -88,7 +91,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
             ) = gen_decode_params(self.b_seq_len)
-            self.max_q_seq_len = 1
             # TODO: check the correctness
             self.max_kv_seq_len = self.max_len_in_batch
             self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]
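
A quick numeric illustration of the two new fields, assuming b_seq_len and b_ready_cache_len are 1-D integer tensors as elsewhere in this class (the values below are made up):

    import torch

    b_seq_len = torch.tensor([7, 5, 9])          # total kv length per request in a prefill batch
    b_ready_cache_len = torch.tensor([3, 0, 4])  # kv tokens already cached per request

    prefix_total_token_num = int(b_ready_cache_len.sum())  # 7, i.e. sum(b_ready_cache_len)
    max_cache_len = int(b_ready_cache_len.max())           # 4, the longest cached prefix
    new_input_lens = b_seq_len - b_ready_cache_len         # tokens actually fed in this prefill
    print(prefix_total_token_num, max_cache_len, new_input_lens.tolist())  # 7 4 [4, 5, 5]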

lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py

Lines changed: 81 additions & 4 deletions
@@ -6,8 +6,7 @@
 
 @triton.jit
 def _fwd_kernel_copy_kv_index_to_req(
-    req_to_token_indexs, b_req_idx, b_seq_len, memindex,
-    stride_req_to_token_b, stride_req_to_token_s
+    req_to_token_indexs, b_req_idx, b_seq_len, memindex, stride_req_to_token_b, stride_req_to_token_s
 ):
     cur_index = tl.program_id(0)
     cur_req_idx = tl.load(b_req_idx + cur_index)
@@ -26,8 +25,86 @@ def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):
     num_warps = 1
 
     _fwd_kernel_copy_kv_index_to_req[grid](
-        req_to_token_indexs, b_req_idx, b_seq_len, memindex,
-        req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
+        req_to_token_indexs,
+        b_req_idx,
+        b_seq_len,
+        memindex,
+        req_to_token_indexs.stride(0),
+        req_to_token_indexs.stride(1),
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_copy_kv_index_to_req_prefill(
+    req_to_token_indexs,
+    b_req_idx,
+    b_seq_len,
+    b_ready_cache_len,
+    b_start_loc,
+    memindex,
+    stride_req_to_token_b,
+    stride_req_to_token_s,
+    BLOCK: tl.constexpr,
+):
+
+    block_index = tl.program_id(0)
+    batch_index = tl.program_id(1)
+    cur_req_idx = tl.load(b_req_idx + batch_index)
+    cur_seq_len = tl.load(b_seq_len + batch_index)
+    cur_ready_cache_len = tl.load(b_ready_cache_len + batch_index)
+    cur_start_loc = tl.load(b_start_loc + batch_index)
+    copy_len = cur_seq_len - cur_ready_cache_len
+
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < copy_len
+    cur_token_index = tl.load(memindex + cur_start_loc + block_range, mask=block_mask)
+    dest_offset = (
+        req_to_token_indexs
+        + cur_req_idx * stride_req_to_token_b
+        + (cur_ready_cache_len + block_range) * stride_req_to_token_s
+    )
+    tl.store(dest_offset, cur_token_index, mask=block_mask)
+
+    return
+
+
+def get_triton_config(max_q_seq_len: int) -> tuple[int, int]:
+    if max_q_seq_len <= 512:
+        return 256, 2
+    elif max_q_seq_len <= 4096:
+        return 512, 4
+    else:
+        return 1024, 8
+
+
+@torch.no_grad()
+def copy_kv_index_to_req_prefill(
+    req_to_token_indexs: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_seq_len: torch.Tensor,
+    b_ready_cache_len: torch.Tensor,
+    b_start_loc: torch.Tensor,
+    memindex: torch.Tensor,
+    max_q_seq_len: int,
+):
+    batch_size = b_req_idx.shape[0]
+    BLOCK, num_warps = get_triton_config(max_q_seq_len)
+    grid = (triton.cdiv(max_q_seq_len, BLOCK), batch_size)
+    num_warps = 1
+
+    _fwd_kernel_copy_kv_index_to_req_prefill[grid](
+        req_to_token_indexs,
+        b_req_idx,
+        b_seq_len,
+        b_ready_cache_len,
+        b_start_loc,
+        memindex,
+        req_to_token_indexs.stride(0),
+        req_to_token_indexs.stride(1),
+        BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
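
To see how the new prefill kernel is meant to be driven, a small host-side usage sketch; the tensor shapes and values are illustrative, and it assumes a CUDA device plus the module path implied by this file:

    import torch
    from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req_prefill

    # Two requests: req 0 already has 3 cached kv tokens and adds 4 new ones,
    # req 1 has no cache and adds 2 new tokens.
    b_req_idx = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    b_seq_len = torch.tensor([7, 2], dtype=torch.int32, device="cuda")
    b_ready_cache_len = torch.tensor([3, 0], dtype=torch.int32, device="cuda")
    # b_start_loc gives each request's offset into the flat memindex buffer.
    b_start_loc = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
    memindex = torch.arange(100, 106, dtype=torch.int32, device="cuda")  # 6 newly allocated kv slots

    req_to_token_indexs = torch.zeros((8, 16), dtype=torch.int32, device="cuda")
    copy_kv_index_to_req_prefill(
        req_to_token_indexs=req_to_token_indexs,
        b_req_idx=b_req_idx,
        b_seq_len=b_seq_len,
        b_ready_cache_len=b_ready_cache_len,
        b_start_loc=b_start_loc,
        memindex=memindex,
        max_q_seq_len=4,  # max(b_seq_len - b_ready_cache_len)
    )
    # Row 0 now holds memindex[0:4] at positions 3..6; row 1 holds memindex[4:6] at positions 0..1.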
