[Bugfix] Fix dp_chunking enablement logic in FusedMoE layer #27220
Conversation
Code Review
This pull request correctly fixes a bug where use_dp_chunking was always evaluating to false. The fix involves moving the initialization logic for DP chunking buffers from the __init__ method to a lazy initialization method, ensure_dp_chunking_init, which is called during the forward pass. This ensures that all necessary configurations are set before the use_dp_chunking property is evaluated. The overall approach is sound. I've pointed out a performance issue in the new initialization method, which can be made idempotent to avoid unnecessary tensor re-allocations on every forward pass.
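To make the described flow concrete, here is a minimal sketch of the lazy-initialization ordering, not the actual FusedMoE code: the config fields (dp_size, max_tokens, hidden_size, num_experts) and buffer shapes are illustrative assumptions, and only the names ensure_dp_chunking_init, use_dp_chunking, forward_impl, batched_hidden_states, and batched_router_logits come from the discussion.

```python
import torch


class FusedMoESketch:
    """Minimal sketch of the lazy-init ordering described above (not vLLM code)."""

    def __init__(self, dp_size: int, max_tokens: int, hidden_size: int, num_experts: int):
        # Buffers are no longer allocated in __init__, because the configuration
        # that use_dp_chunking depends on may not be fully populated yet.
        self.dp_size = dp_size
        self.max_tokens = max_tokens
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.batched_hidden_states = None
        self.batched_router_logits = None

    @property
    def use_dp_chunking(self) -> bool:
        # Placeholder predicate; the real property inspects the MoE
        # parallel / quant configuration.
        return self.dp_size > 1

    def ensure_dp_chunking_init(self) -> None:
        # By the time the forward pass runs, all configuration is set,
        # so the property evaluates correctly here.
        if not self.use_dp_chunking:
            return
        self.batched_hidden_states = torch.zeros(self.max_tokens, self.hidden_size)
        self.batched_router_logits = torch.zeros(self.max_tokens, self.num_experts)

    def forward_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Lazy init happens on the forward path, before any chunking decision.
        self.ensure_dp_chunking_init()
        return hidden_states
```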
```python
if not self.use_dp_chunking:
    return
```
This initialization method is called on every forward pass. To avoid re-allocating tensors unnecessarily, which can impact performance, it's best to make this method idempotent. You can add a check to see if self.batched_hidden_states has already been initialized.
Suggested change:

```diff
-if not self.use_dp_chunking:
-    return
+if not self.use_dp_chunking or self.batched_hidden_states is not None:
+    return
```
I think this comment should be applied.
+1
nice :)
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
self.batched_hidden_states = torch.zeros(
    states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)

self.batched_router_logits = torch.zeros(
```
Avoid re-allocating DP chunk buffers on every forward
The new ensure_dp_chunking_init always recreates self.batched_hidden_states and self.batched_router_logits with fresh torch.zeros tensors whenever it is called, without first checking whether the buffers already exist. Because this helper is now invoked on every forward (and again inside the chunked path), DP runs will allocate and zero large staging tensors twice per step instead of reusing them as the constructor previously did, adding significant GPU allocation/zeroing overhead and memory churn. Consider only creating these tensors when they are None and reusing them thereafter.
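A minimal method sketch of the reuse pattern this finding recommends, assuming placeholder shape/dtype attributes (max_num_tokens, hidden_size, global_num_experts, in_dtype) rather than the real MoE config:

```python
import torch


def ensure_dp_chunking_init(self) -> None:
    """Sketch of an idempotent version: allocate once, then reuse."""
    # Skip when chunking is off, or when the buffers were already created
    # on an earlier forward pass.
    if not self.use_dp_chunking or self.batched_hidden_states is not None:
        return
    # Placeholder shapes/dtype/device; the real values come from the MoE config.
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    states_shape = (self.max_num_tokens, self.hidden_size)
    logits_shape = (self.max_num_tokens, self.global_num_experts)
    self.batched_hidden_states = torch.zeros(
        states_shape, dtype=self.in_dtype, device=device
    )
    self.batched_router_logits = torch.zeros(
        logits_shape, dtype=self.in_dtype, device=device
    )
```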
@bnellnm could you take a quick look?
Another option is to just change how use_dp_chunking works. Or maybe even better would be to move the buffer initialization as the PR does but also change use_dp_chunking so it doesn't depend on the quant config.
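Purely as a hypothetical illustration of that alternative (the attribute names below are not from the vLLM source), the predicate could rely only on state already available at construction time:

```python
@property
def use_dp_chunking(self) -> bool:
    # Hypothetical variant: depend only on attributes that are set in
    # __init__ (e.g. the parallel config), not on the lazily-populated
    # quant config, so the property is safe to evaluate early.
    return self.dp_size > 1 and self.use_batched_dispatch
```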
```diff
-self.ensure_moe_quant_config()
+self.ensure_moe_quant_config_init()
+self.ensure_dp_chunking_init()
```
I think we don't need self.ensure_dp_chunking_init() here, since we already call it in forward_impl_chunked where it is used?
Actually it's the opposite: the call in forward_impl_chunked is the duplicate, so I can remove it there, since forward_impl is the only caller of forward_impl_chunked.
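A sketch of the resulting call structure being discussed, shown as standalone method stubs of the layer: forward_impl_non_chunked is a placeholder name for the non-chunked path, and the bodies are elided.

```python
def forward_impl(self, hidden_states, router_logits):
    # The single place that performs lazy initialization before any
    # chunking decision is made.
    self.ensure_moe_quant_config_init()
    self.ensure_dp_chunking_init()
    if self.use_dp_chunking:
        return self.forward_impl_chunked(hidden_states, router_logits)
    return self.forward_impl_non_chunked(hidden_states, router_logits)


def forward_impl_chunked(self, hidden_states, router_logits):
    # No ensure_dp_chunking_init() call here: forward_impl is the only
    # caller and has already performed the initialization.
    ...
```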
I agree with @bnellnm's suggestion ^.
About use_dp_chunking(): I would prefer to leave the function as is, since otherwise the checks would be very different from what it checks now. It doesn't seem like a big deal that it depends on moe_quant_config, since we want to inspect the quant config's params anyway to limit the behaviors (if necessary).
Signed-off-by: Alexander Matveev <[email protected]>
The "use_dp_chunking" in FusedMoE was always false due to control flow mistake => this PR fixes the logic to be enabled for DP runs as intended.