@@ -67,12 +67,16 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.data_type = kvargs.get("data_type", "float16")
+        mtp_step = get_env_start_args().mtp_step
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (
             self.graph_max_batch_size // 2
             if get_env_start_args().enable_decode_microbatch_overlap
             else self.graph_max_batch_size
         )
+        # In mtp mode the max batch size must be adjusted to a multiple of (mtp_step + 1).
+        self.graph_max_batch_size = self.graph_max_batch_size * (mtp_step + 1)
+
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
         self.quant_type = kvargs.get("quant_type", "none")
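
The interaction between the microbatch-overlap halving and the new mtp scaling above is easy to misread; here is a minimal standalone sketch of the same arithmetic (the helper name and the standalone form are illustrative, not part of the patch):

```python
# Illustrative sketch of the graph_max_batch_size adjustment above (not part of the patch).
def resolve_graph_max_batch_size(requested: int, microbatch_overlap: bool, mtp_step: int) -> int:
    # Each microbatch gets half of the requested CUDA-graph batch budget.
    size = requested // 2 if microbatch_overlap else requested
    # With multi-token prediction, every request occupies (mtp_step + 1) decode slots,
    # so the captured graph batch size is scaled to stay a multiple of (mtp_step + 1).
    return size * (mtp_step + 1)

assert resolve_graph_max_batch_size(16, microbatch_overlap=True, mtp_step=1) == 16
assert resolve_graph_max_batch_size(16, microbatch_overlap=False, mtp_step=2) == 48
```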
@@ -81,7 +85,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
-        self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
+        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
 
         self._init_datatype()
         self._init_config()
@@ -258,6 +262,10 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
         infer_state.batch_size = model_input.batch_size
         infer_state.total_token_num = model_input.total_token_num
         infer_state.max_len_in_batch = model_input.max_len_in_batch
+        infer_state.max_q_seq_len = model_input.max_q_seq_len
+        infer_state.max_kv_seq_len = model_input.max_kv_seq_len
+        infer_state.max_cache_len = model_input.max_cache_len
+        infer_state.prefix_total_token_num = model_input.prefix_total_token_num
         assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
         infer_state.b_req_idx = model_input.b_req_idx
         infer_state.b_seq_len = model_input.b_seq_len
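
For readers following the new fields copied into the inference state, a minimal sketch of how `ModelInput` could carry them (field names are taken from the diff; the dataclass form and the omission of other members are assumptions):

```python
from dataclasses import dataclass
import torch

# Minimal sketch of the ModelInput fields touched by this change (assumed shape; the real
# class in the repo carries many more members).
@dataclass
class ModelInputSketch:
    batch_size: int
    total_token_num: int
    max_len_in_batch: int
    max_q_seq_len: int           # longest query segment in the batch (new tokens per request)
    max_kv_seq_len: int          # longest key/value sequence, i.e. cached prefix + new tokens
    max_cache_len: int           # longest already-cached prefix length in the batch
    prefix_total_token_num: int  # sum of cached prefix lengths over the batch
    input_ids: torch.Tensor
    b_req_idx: torch.Tensor
    b_seq_len: torch.Tensor
```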
@@ -335,16 +343,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state.b_req_idx,
+            b_seq_len=infer_state.b_seq_len,
+            b_ready_cache_len=infer_state.b_ready_cache_len,
+            b_start_loc=infer_state.b_start_loc,
+            alloc_mem_index=infer_state.mem_index,
+            max_q_seq_len=infer_state.max_q_seq_len,
         )
-
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
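
The call now passes `b_start_loc` and `max_q_seq_len` explicitly and reads everything from the inference state. The kernel itself is not shown in this diff, so the following pure-PyTorch sketch is only an assumption about its semantics, meant to show what the new keyword arguments are used for:

```python
import torch

def init_req_to_token_indexes_sketch(
    req_to_token_indexs: torch.Tensor,  # [max_reqs, max_seq_len] request-to-KV-slot table
    b_req_idx: torch.Tensor,            # [batch] request slots
    b_seq_len: torch.Tensor,            # [batch] cached prefix + new tokens
    b_ready_cache_len: torch.Tensor,    # [batch] already-cached prefix length
    b_start_loc: torch.Tensor,          # [batch] start of each request's segment in alloc_mem_index
    alloc_mem_index: torch.Tensor,      # [sum of new tokens] freshly allocated KV-cache slots
    max_q_seq_len: int,                 # upper bound on new tokens per request (launch bound)
):
    # Assumed semantics: scatter each request's newly allocated KV slots into its row of the
    # request-to-token table, starting right after the already-cached prefix.
    for i in range(b_req_idx.shape[0]):
        req = int(b_req_idx[i])
        cached = int(b_ready_cache_len[i])
        new_tokens = int(b_seq_len[i]) - cached
        start = int(b_start_loc[i])
        req_to_token_indexs[req, cached:cached + new_tokens] = alloc_mem_index[start:start + new_tokens]
```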
@@ -474,26 +482,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
+        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state0.b_req_idx,
+            b_seq_len=infer_state0.b_seq_len,
+            b_ready_cache_len=infer_state0.b_ready_cache_len,
+            b_start_loc=infer_state0.b_start_loc,
+            alloc_mem_index=infer_state0.mem_index,
+            max_q_seq_len=infer_state0.max_q_seq_len,
         )
-        infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
+        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
+            req_to_token_indexs=self.req_manager.req_to_token_indexs,
+            b_req_idx=infer_state1.b_req_idx,
+            b_seq_len=infer_state1.b_seq_len,
+            b_ready_cache_len=infer_state1.b_ready_cache_len,
+            b_start_loc=infer_state1.b_start_loc,
+            alloc_mem_index=infer_state1.mem_index,
+            max_q_seq_len=infer_state1.max_q_seq_len,
         )
-        infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -521,7 +531,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
             model_input1.b_req_idx,
             model_input1.b_mtp_index,
         )
-
+        # TODO: fix for dynamic mtp
         assert model_input0.batch_size == model_input1.batch_size
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda
@@ -531,6 +541,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
 
         if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size)
+            # TODO: if mtp with a dynamic number of steps is supported, model_input0 and model_input1 may have
+            # different internal batch sizes per mtp step; the graph lookup must then use the larger batch size, and the outputs must be restored accordingly.
            padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
            padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
            infer_state0 = self._create_inferstate(padded_model_input0, 0)
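
The TODO above concerns how the graph lookup and padding would behave under dynamic mtp. A hedged sketch of the lookup-and-pad pattern being discussed follows; the rounding policy shown is an assumption for illustration, not the repo's actual `find_closest_graph_batch_size` implementation:

```python
def find_closest_graph_batch_size_sketch(captured_sizes: list[int], batch_size: int) -> int:
    # Assumed policy: pick the smallest captured CUDA-graph batch size that can hold the batch.
    candidates = [s for s in sorted(captured_sizes) if s >= batch_size]
    if not candidates:
        raise ValueError("batch larger than any captured graph")
    return candidates[0]

# With dynamic mtp steps the two microbatches could shrink at different rates, e.g. to 7 and 5
# live rows; the lookup would then have to use max(7, 5), and each output would be unpadded
# back to its own original batch size afterwards.
assert find_closest_graph_batch_size_sketch([1, 2, 4, 8, 16], max(7, 5)) == 8
```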
@@ -568,6 +580,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
                 input_ids1=padded_model_input1.input_ids,
                 infer_state1=infer_state1,
             )
+
+            # TODO: fix for dynamic mtp
             model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size)
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
@@ -696,6 +710,10 @@ def _check_max_len_infer(self):
             batch_size=1,
             total_token_num=total_token_num,
             max_len_in_batch=self.batch_max_tokens,
+            max_q_seq_len=self.batch_max_tokens,
+            max_kv_seq_len=self.batch_max_tokens,
+            max_cache_len=0,
+            prefix_total_token_num=0,
             input_ids=dummy_input_ids,
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
@@ -766,6 +784,10 @@ def _autotune_warmup(self):
             batch_size=1,
             total_token_num=total_token_num,
             max_len_in_batch=input_len,
+            max_q_seq_len=input_len,
+            max_kv_seq_len=input_len,
+            max_cache_len=0,
+            prefix_total_token_num=0,
             input_ids=dummy_input_ids,
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
@@ -822,6 +844,10 @@ def _init_padded_req(self):
             batch_size=batch_size,
             total_token_num=total_token_num,
             max_len_in_batch=prefill_input_len,
+            max_q_seq_len=prefill_input_len,
+            max_kv_seq_len=prefill_input_len,
+            max_cache_len=0,
+            prefix_total_token_num=0,
             input_ids=dummy_input_ids,
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,