58 changes: 32 additions & 26 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1099,6 +1099,7 @@ def __init__(
         self.params_dtype = params_dtype

         vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config

         # FIXME (varun): We should have a better way of inferring the activation
         # datatype. This works for now as the tensor datatype entering the MoE
@@ -1320,26 +1321,6 @@ def __init__(
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None

-        if self.use_dp_chunking:
-            states_shape: tuple[int, ...]
-            logits_shape: tuple[int, ...]
-
-            # Note here we use `num_experts` which is logical expert count
-            if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
-            else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
-
-            self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
-            self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return None
@@ -1398,8 +1379,6 @@ def use_flashinfer_cutlass_kernels(self):

     @property
     def use_dp_chunking(self) -> bool:
-        # Route to the chunked forward path using the FlashInfer Cutlass kernel
-        # only when data parallelism (DP) is enabled.
         return (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels
@@ -1963,12 +1942,40 @@ def set_eplb_state(
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]

-    def ensure_moe_quant_config(self):
+    def ensure_moe_quant_config_init(self):
         if self.quant_method.moe_quant_config is None:
             self.quant_method.moe_quant_config = (
                 self.quant_method.get_fused_moe_quant_config(self)
             )

+        if self.moe_quant_config is None:
+            self.moe_quant_config = self.quant_method.moe_quant_config
+
+    def ensure_dp_chunking_init(self):
+        if not self.use_dp_chunking or self.batched_hidden_states is not None:
+            return
+
+        states_shape: tuple[int, ...]
+        logits_shape: tuple[int, ...]
+
+        moe = self.moe_config
+
+        # Note here we use `num_experts` which is logical expert count
+        if self.vllm_config.parallel_config.enable_dbo:
+            states_shape = (2, moe.max_num_tokens, self.hidden_size)
+            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+        else:
+            states_shape = (moe.max_num_tokens, self.hidden_size)
+            logits_shape = (moe.max_num_tokens, moe.num_experts)
+
+        self.batched_hidden_states = torch.zeros(
+            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
+
+        self.batched_router_logits = torch.zeros(

Comment on lines +1971 to +1975:

P1: Avoid re-allocating DP chunk buffers on every forward

The new ensure_dp_chunking_init always recreates self.batched_hidden_states and self.batched_router_logits with fresh torch.zeros tensors whenever it is called, without first checking whether the buffers already exist. Because this helper is now invoked on every forward (and again inside the chunked path), DP runs will allocate and zero large staging tensors twice per step instead of reusing them as the constructor previously did, adding significant GPU allocation/zeroing overhead and memory churn. Consider only creating these tensors when they are None and reusing them thereafter.

+            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
+
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,
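
The P1 note above hinges on the staging buffers being allocated once and then reused; the early return in ensure_dp_chunking_init (when self.batched_hidden_states is not None) is what provides that. Below is a minimal, self-contained sketch of the same lazy-init-and-reuse pattern, using a hypothetical ChunkBuffers class with CPU tensors and an arbitrary dtype, not vLLM's actual module:

import torch


class ChunkBuffers:
    """Hypothetical stand-in illustrating lazy, reusable staging buffers."""

    def __init__(self, max_num_tokens: int, hidden_size: int, num_experts: int):
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        # Nothing is allocated in the constructor; allocation is deferred.
        self.batched_hidden_states: torch.Tensor | None = None
        self.batched_router_logits: torch.Tensor | None = None

    def ensure_init(self) -> None:
        # Early return makes every call after the first allocation-free.
        if self.batched_hidden_states is not None:
            return
        self.batched_hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=torch.float32
        )
        self.batched_router_logits = torch.zeros(
            (self.max_num_tokens, self.num_experts), dtype=torch.float32
        )


buffers = ChunkBuffers(max_num_tokens=8, hidden_size=16, num_experts=4)
buffers.ensure_init()
first = buffers.batched_hidden_states
buffers.ensure_init()  # second call reuses the buffers instead of re-allocating
assert buffers.batched_hidden_states is first
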
@@ -2199,8 +2206,6 @@ def forward_impl_chunked(
         assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
         assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)

-        self.ensure_moe_quant_config()
-
         full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
         if self.shared_experts is not None:
             full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
@@ -2358,7 +2363,8 @@ def forward_impl(
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.quant_method is not None

-        self.ensure_moe_quant_config()
+        self.ensure_moe_quant_config_init()
+        self.ensure_dp_chunking_init()

varun-sundar-rabindranath (Contributor) commented on Oct 20, 2025:

I think we need not have self.ensure_dp_chunking_init() here, as we are doing it anyway in forward_impl_chunked where it is used?

Author (Collaborator) replied:

Actually it is the opposite: the call in forward_impl_chunked is the duplicate, so I can remove it there, since forward_impl is the only one that calls forward_impl_chunked.


         has_separate_shared_experts = (
             not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
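
To make the call ordering from the thread above concrete, here is a runnable toy sketch (placeholder method bodies, not the actual vLLM implementation): because forward_impl is the only caller of forward_impl_chunked, running both ensure_* helpers once at the top of forward_impl covers the chunked path as well.

class MoECallFlowSketch:
    """Toy illustration of the call ordering; names mirror the diff above."""

    def __init__(self, use_dp_chunking: bool):
        self.use_dp_chunking = use_dp_chunking
        self.initialized = False

    def ensure_moe_quant_config_init(self) -> None:
        # Stand-in for the lazy quant-config resolution.
        self.initialized = True

    def ensure_dp_chunking_init(self) -> None:
        # Stand-in for the lazy staging-buffer allocation.
        self.initialized = True

    def forward_impl_chunked(self, x):
        # Safe to rely on forward_impl having already run the ensure_* helpers.
        assert self.initialized
        return x

    def forward_impl(self, x):
        # Single entry point: lazy setup runs here, before any dispatch.
        self.ensure_moe_quant_config_init()
        self.ensure_dp_chunking_init()
        if self.use_dp_chunking:
            return self.forward_impl_chunked(x)
        return x


print(MoECallFlowSketch(use_dp_chunking=True).forward_impl("tokens"))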