Conversation

@alexm-redhat (Collaborator)

The "use_dp_chunking" in FusedMoE was always false due to control flow mistake => this PR fixes the logic to be enabled for DP runs as intended.

@alexm-redhat requested a review from mgoin as a code owner October 20, 2025 19:54
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request correctly fixes a bug where use_dp_chunking was always evaluating to false. The fix involves moving the initialization logic for DP chunking buffers from the __init__ method to a lazy initialization method, ensure_dp_chunking_init, which is called during the forward pass. This ensures that all necessary configurations are set before the use_dp_chunking property is evaluated. The overall approach is sound. I've pointed out a performance issue in the new initialization method, which can be made idempotent to avoid unnecessary tensor re-allocations on every forward pass.

Comment on lines 1916 to 1917

    if not self.use_dp_chunking:
        return
gemini-code-assist bot (severity: high):

This initialization method is called on every forward pass. To avoid re-allocating tensors unnecessarily, which can impact performance, it's best to make this method idempotent. You can add a check to see if self.batched_hidden_states has already been initialized.

Suggested change

    - if not self.use_dp_chunking:
    -     return
    + if not self.use_dp_chunking or self.batched_hidden_states is not None:
    +     return

Contributor:
I think this comment should be applied.

Contributor:
+1

@alexm-redhat (Collaborator, Author):
nice :)

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines +1932 to +1975

    self.batched_hidden_states = torch.zeros(
        states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
    )

    self.batched_router_logits = torch.zeros(

P1: Avoid re-allocating DP chunk buffers on every forward

The new ensure_dp_chunking_init always recreates self.batched_hidden_states and self.batched_router_logits with fresh torch.zeros tensors whenever it is called, without first checking whether the buffers already exist. Because this helper is now invoked on every forward (and again inside the chunked path), DP runs will allocate and zero large staging tensors twice per step instead of reusing them as the constructor previously did, adding significant GPU allocation/zeroing overhead and memory churn. Consider only creating these tensors when they are None and reusing them thereafter.
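
For reference, here is a minimal sketch of the guarded, lazily-initialized helper that both reviews are asking for. The buffer names and the overall allocation pattern follow the snippets quoted above; the exact shapes, dtypes, and config attributes are assumptions, not the actual vLLM implementation.

    import torch

    # Minimal sketch; assumes __init__ sets self.batched_hidden_states and
    # self.batched_router_logits to None, and that shapes/dtypes come from the
    # MoE config (placeholder attribute names below).
    def ensure_dp_chunking_init(self) -> None:
        # Skip when DP chunking is unused, or when the staging buffers were
        # already allocated by an earlier forward pass (idempotency guard).
        if not self.use_dp_chunking or self.batched_hidden_states is not None:
            return

        # Allocate the staging buffers once and reuse them afterwards.
        device = torch.cuda.current_device()
        states_shape = (self.moe.max_num_tokens, self.moe.hidden_dim)   # assumed
        logits_shape = (self.moe.max_num_tokens, self.moe.num_experts)  # assumed
        self.batched_hidden_states = torch.zeros(
            states_shape, dtype=self.moe.in_dtype, device=device
        )
        self.batched_router_logits = torch.zeros(
            logits_shape, dtype=self.moe.in_dtype, device=device
        )

With this guard, calling the helper on every forward pass is cheap: only the first DP-chunked forward pays the allocation cost, and later calls return immediately.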

@alexm-redhat self-assigned this Oct 20, 2025
@alexm-redhat (Collaborator, Author):

@bnellnm could you take a quick look

@bnellnm (Contributor) commented Oct 20, 2025

Another option is to just change how use_dp_chunking is defined so it doesn't depend on moe_quant_config, e.g.

    @property
    def use_dp_chunking(self) -> bool:
        # Route to the chunked forward path using the FlashInfer Cutlass kernel
        # only when data parallelism (DP) is enabled.
        return (
            self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels
            or self.moe_parallel_config.use_deepep_hybrid_kernels
            or (self.dp_size > 1 and self.moe.use_flashinfer_cutlass_kernels)
        )

Or maybe even better would be to move the buffer initialization as the PR does but change use_dp_chunking to check the actual format, e.g.

    @property
    def use_dp_chunking(self) -> bool:
        return (
            self.quant_method.fused_experts is not None
            and self.quant_method.fused_experts.prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts
        )

cc @varun-sundar-rabindranath


    self.ensure_moe_quant_config()
    self.ensure_moe_quant_config_init()
    self.ensure_dp_chunking_init()
@varun-sundar-rabindranath (Contributor) commented Oct 20, 2025:

I think we don't need self.ensure_dp_chunking_init() here, since we already do it in forward_impl_chunked where it is used?

@alexm-redhat (Collaborator, Author):

Actually it's the opposite: the call in forward_impl_chunked is the duplicate, so I can remove it there, since forward_impl is the only one that calls forward_impl_chunked.
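
To make that call structure concrete, here is a minimal sketch. The method names follow the discussion in this thread; the surrounding class and the non-chunked path are assumptions, not the actual vLLM source.

    # Sketch only: forward_impl owns the single lazy-init call before
    # dispatching to the chunked path.
    def forward_impl(self, hidden_states, router_logits):
        self.ensure_moe_quant_config_init()
        # Single lazy-init call site: forward_impl is the only caller of
        # forward_impl_chunked, so the duplicate call inside the chunked
        # path can be removed.
        self.ensure_dp_chunking_init()

        if self.use_dp_chunking:
            return self.forward_impl_chunked(hidden_states, router_logits)
        # ... otherwise continue with the regular (non-chunked) path ...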

@varun-sundar-rabindranath (Contributor):

maybe even better would be to move the buffer initialization as the PR does but change use_dp_chunking to check the actual format, e.g.

    @property
    def use_dp_chunking(self) -> bool:
        return (
            self.quant_method.fused_experts is not None
            and self.quant_method.fused_experts.prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts
        )

I agree with @bnellnm's suggestion above.

@alexm-redhat (Collaborator, Author) commented Oct 22, 2025

Regarding use_dp_chunking(): I would prefer to leave the function as is, since otherwise the checks would be quite different from what it checks now. It doesn't seem like a big deal that it depends on moe_quant_config, since we want to inspect the quant config's params anyway to limit the behaviors (if necessary).

@alexm-redhat (Collaborator, Author):

@bnellnm @mgoin ready for re-review

@mgoin added labels on Oct 22, 2025: bug (Something isn't working), moe, ready (ONLY add when PR is ready to merge/full CI is needed)
@DarkLight1337 merged commit 9ef3d5b into main Oct 23, 2025
55 checks passed
@DarkLight1337 deleted the dp_fix branch October 23, 2025 16:03
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025