[Bugfix] Fix dp_chunking enablement logic in FusedMoE layer #27220
Changes from all commits
@@ -1099,6 +1099,7 @@ def __init__(
         self.params_dtype = params_dtype
 
         vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
 
         # FIXME (varun): We should have a better way of inferring the activation
         # datatype. This works for now as the tensor datatype entering the MoE

@@ -1320,26 +1321,6 @@ def __init__(
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None
 
-        if self.use_dp_chunking:
-            states_shape: tuple[int, ...]
-            logits_shape: tuple[int, ...]
-
-            # Note here we use `num_experts` which is logical expert count
-            if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
-            else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
-
-            self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
-            self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return None

@@ -1398,8 +1379,6 @@ def use_flashinfer_cutlass_kernels(self):
 
     @property
     def use_dp_chunking(self) -> bool:
-        # Route to the chunked forward path using the FlashInfer Cutlass kernel
-        # only when data parallelism (DP) is enabled.
         return (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels

@@ -1963,12 +1942,40 @@ def set_eplb_state(
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]
 
-    def ensure_moe_quant_config(self):
+    def ensure_moe_quant_config_init(self):
         if self.quant_method.moe_quant_config is None:
             self.quant_method.moe_quant_config = (
                 self.quant_method.get_fused_moe_quant_config(self)
             )
+
+        if self.moe_quant_config is None:
+            self.moe_quant_config = self.quant_method.moe_quant_config
+
+    def ensure_dp_chunking_init(self):
+        if not self.use_dp_chunking or self.batched_hidden_states is not None:
+            return
+
+        states_shape: tuple[int, ...]
+        logits_shape: tuple[int, ...]
+
+        moe = self.moe_config
+
+        # Note here we use `num_experts` which is logical expert count
+        if self.vllm_config.parallel_config.enable_dbo:
+            states_shape = (2, moe.max_num_tokens, self.hidden_size)
+            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+        else:
+            states_shape = (moe.max_num_tokens, self.hidden_size)
+            logits_shape = (moe.max_num_tokens, moe.num_experts)
+
+        self.batched_hidden_states = torch.zeros(
+            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
+
+        self.batched_router_logits = torch.zeros(
+            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
 
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,

@@ -2199,8 +2206,6 @@ def forward_impl_chunked(
         assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
         assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
 
-        self.ensure_moe_quant_config()
-
         full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
         if self.shared_experts is not None:
             full_shared_final_hidden_states = torch.empty_like(full_hidden_states)

@@ -2358,7 +2363,8 @@ def forward_impl(
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.quant_method is not None
 
-        self.ensure_moe_quant_config()
+        self.ensure_moe_quant_config_init()
+        self.ensure_dp_chunking_init()
 
         has_separate_shared_experts = (
             not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)

Review thread on this hunk:

Comment: I think we need not have […]

Reply: Actually it is the opposite: the call in forward_impl_chunked is the duplicate, so I can remove it there, since forward_impl is the only one that calls forward_impl_chunked.
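The exchange above hinges on the call structure: forward_impl is the sole caller of forward_impl_chunked, so the lazy ensure_* helpers only need to run in forward_impl. Below is a minimal sketch of that structure; the class name, shapes, dtype, and method bodies are hypothetical stand-ins for illustration, not vLLM's actual FusedMoE implementation.

```python
import torch


class FusedMoELayerSketch:
    """Illustrative skeleton only; mirrors the call pattern discussed above."""

    def __init__(self, hidden_size: int, max_num_tokens: int, use_dp_chunking: bool):
        self.hidden_size = hidden_size
        self.max_num_tokens = max_num_tokens
        self.use_dp_chunking = use_dp_chunking
        # Staging buffer starts out unallocated; it is created lazily on first use.
        self.batched_hidden_states: torch.Tensor | None = None

    def ensure_dp_chunking_init(self) -> None:
        # Idempotent lazy init: allocate once, then reuse on every later call.
        if not self.use_dp_chunking or self.batched_hidden_states is not None:
            return
        self.batched_hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=torch.float32
        )

    def forward_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Single entry point: the ensure_* helper runs here, so the chunked
        # path below can assume the staging buffer already exists.
        self.ensure_dp_chunking_init()
        if self.use_dp_chunking:
            return self.forward_impl_chunked(hidden_states)
        return hidden_states  # stand-in for the non-chunked path

    def forward_impl_chunked(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only reachable via forward_impl, so no duplicate ensure_* call is needed.
        assert self.batched_hidden_states is not None
        n = hidden_states.size(0)
        self.batched_hidden_states[:n].copy_(hidden_states)
        return self.batched_hidden_states[:n].clone()


layer = FusedMoELayerSketch(hidden_size=16, max_num_tokens=8, use_dp_chunking=True)
out = layer.forward_impl(torch.randn(4, 16))
assert out.shape == (4, 16)
```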
Comment: The new ensure_dp_chunking_init always recreates self.batched_hidden_states and self.batched_router_logits with fresh torch.zeros tensors whenever it is called, without first checking whether the buffers already exist. Because this helper is now invoked on every forward (and again inside the chunked path), DP runs will allocate and zero large staging tensors twice per step instead of reusing them as the constructor previously did, adding significant GPU allocation/zeroing overhead and memory churn. Consider only creating these tensors when they are None and reusing them thereafter.
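Continuing the hypothetical sketch above, a small check of the reuse behaviour this comment asks for: when ensure_dp_chunking_init returns early on an already-allocated buffer (as the None guard in the diff's ensure_dp_chunking_init does), repeated forwards keep the same tensor object rather than reallocating it each step. Names here belong to the sketch, not to vLLM.

```python
# Reuses FusedMoELayerSketch and the torch import from the sketch above.
layer = FusedMoELayerSketch(hidden_size=16, max_num_tokens=8, use_dp_chunking=True)

layer.ensure_dp_chunking_init()
first_buffer = layer.batched_hidden_states

for _ in range(3):
    layer.forward_impl(torch.randn(4, 16))

# The early return on an already-allocated buffer means the same tensor object
# is reused on every step: no per-forward allocation or zeroing.
assert layer.batched_hidden_states is first_buffer
```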