4 changes: 2 additions & 2 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -161,10 +161,10 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)

# 4. Cascade attention should have a fall back mode
# 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
2 changes: 1 addition & 1 deletion vllm/forward_context.py
@@ -292,7 +292,7 @@ def set_forward_context(
if num_tokens_across_dp is None:
assert ubatch_slices is None
assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp(
_, num_tokens_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config,
allow_microbatching=False,
4 changes: 2 additions & 2 deletions vllm/v1/cudagraph_dispatcher.py
@@ -145,7 +145,7 @@ def dispatch(
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions (e.g., batch descriptor and if using cascade attention),
@@ -165,7 +165,7 @@
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

if not use_cascade_attn:
if not disable_full:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
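The renamed flag generalizes the old cascade-attention special case: when disable_full=True, the dispatcher skips its FULL-graph keys and falls back to a piecewise graph or eager mode. A minimal sketch of the call pattern, mirroring the updated test above (the `dispatcher` object and its configured keys are assumptions of the sketch, not part of this diff):

# Sketch only: assumes `dispatcher` is a CudagraphDispatcher whose keys were
# initialized from a config that provides both FULL and PIECEWISE graph keys.

# Default path: an exact FULL key may be selected for this batch shape.
rt_mode, key = dispatcher.dispatch(
    num_tokens=8, uniform_decode=False, has_lora=False
)

# A caller that cannot use a full graph (cascade attention, or a DP peer that
# synced down to PIECEWISE/NONE) passes disable_full=True; the dispatcher then
# skips the FULL keys and falls back to PIECEWISE, or NONE if nothing matches.
rt_mode, key = dispatcher.dispatch(
    num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)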
2 changes: 1 addition & 1 deletion vllm/v1/spec_decode/eagle.py
@@ -1258,7 +1258,7 @@ def _pad_batch_across_dp(
num_tokens_padded: int,
) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching
should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
49 changes: 37 additions & 12 deletions vllm/v1/worker/dp_utils.py
@@ -40,16 +40,18 @@ def _run_ar(
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
Comment on lines +49 to +54
high

The use of magic numbers 0, 1, 2, 3, 4 for indexing into the tensor makes the code hard to read and maintain. It's not immediately clear what each index represents without looking at the surrounding code or comments. This pattern is also present in _post_process_cudagraph_mode with tensor[4, :]. This can lead to bugs if the order or size of the tensor changes.

I recommend defining these indices as constants at the module level, for example, using an Enum. This would make the code self-documenting and less error-prone across all functions that use this tensor (_run_ar, _post_process_ubatch, _post_process_dp_padding, _post_process_cudagraph_mode).

For example:

from enum import IntEnum

class DPSync(IntEnum):
    ORIG_NUM_TOKENS_PER_UBATCH = 0
    PADDED_NUM_TOKENS_PER_UBATCH = 1
    SHOULD_UBATCH = 2
    SHOULD_DP_PAD = 3
    CUDAGRAPH_MODE = 4
    TENSOR_SIZE = 5

Then you could use tensor[DPSync.CUDAGRAPH_MODE] instead of tensor[4].

dist.all_reduce(tensor, group=group)
return tensor

@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()


def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
"""
Synchronize cudagraph_mode across DP ranks by taking the minimum.
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())


def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None]:
) -> tuple[bool, torch.Tensor | None, int]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do.
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that they run with the same number of tokens
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including any DP padding.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)

@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
should_dp_pad,
)

return should_ubatch, num_tokens_after_padding
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode


def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None,
) -> tuple[bool, torch.Tensor | None]:
cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
@@ -175,6 +195,7 @@
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
@@ -183,12 +204,13 @@
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return False, None
return False, None, cudagraph_mode

# If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False
@@ -204,12 +226,15 @@
if num_tokens_padded is None:
num_tokens_padded = num_tokens_unpadded

(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
parallel_config,
(should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
_synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
)

return (should_ubatch, num_tokens_after_padding)
return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)


P1: Align coordinate_batch_across_dp unpacking with new return

coordinate_batch_across_dp now returns three values including the synchronized cudagraph mode (return at line 240), but callers such as set_forward_context in forward_context.py (around lines 295–300) and eagle._pad_batch_across_dp in v1/spec_decode/eagle.py (around lines 1261–1269) still unpack only two items. In multi-DP runs where these paths invoke coordinate_batch_across_dp, Python will raise ValueError: too many values to unpack before padding or execution begins, breaking DP execution for forward contexts and EAGLE. Callers need to accept the third element or the function must preserve the previous 2-tuple interface.
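As a reference for the caller-side contract, a minimal sketch of both ways to consume the widened return value (names mirror the diff above; keyword arguments beyond those shown in this PR are elided):

# Discard the synced mode where it is not needed, as set_forward_context does
# in this diff:
_, num_tokens_across_dp, _ = coordinate_batch_across_dp(
    num_tokens_unpadded=num_tokens,
    parallel_config=vllm_config.parallel_config,
    allow_microbatching=False,
    # remaining keyword arguments unchanged from the call site
)

# Or consume it, as gpu_model_runner does, to decide whether full cudagraphs
# must be disabled when re-dispatching after DP padding:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
    coordinate_batch_across_dp(
        num_tokens_unpadded=num_tokens,
        parallel_config=self.parallel_config,
        cudagraph_mode=cudagraph_mode.value,
        # remaining keyword arguments unchanged from the call site
    )
)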


37 changes: 22 additions & 15 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2764,17 +2764,19 @@ def _determine_batch_execution_and_padding(
)

dispatch_cudagraph = (
lambda num_tokens: self.cudagraph_dispatcher.dispatch(
lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens,
has_lora=has_lora,
use_cascade_attn=use_cascade_attn,
uniform_decode=uniform_decode,
disable_full=disable_full,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
)

cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded, use_cascade_attn
)
num_tokens_padded = batch_descriptor.num_tokens

# Extra coordination when running data-parallel since we need to coordinate
@@ -2789,23 +2791,28 @@
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)

should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
cudagraph_mode=cudagraph_mode.value,
)
)

# Extract DP padding if there is any
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())

# Re-dispatch with DP padding
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
# Re-dispatch with DP padding so we have the correct batch_descriptor
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded,
disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
)
# Assert to make sure the agreed-upon token count is correct, otherwise
# num_tokens_across_dp will no longer be valid
assert batch_descriptor.num_tokens == num_tokens_padded
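Taken together, the synchronization is a min-reduction over the per-rank modes followed by the threshold check used in the re-dispatch above. A small self-contained illustration (the 0/1/2 encoding comes from the coordinate_batch_across_dp docstring; this is an illustration of the decision logic, not the runner code itself):

import torch

# Mode encoding per the docstring: 0 = NONE, 1 = PIECEWISE, 2 = FULL.
NONE, PIECEWISE, FULL = 0, 1, 2

# Row 4 of the all-reduced tensor holds one cudagraph mode per DP rank.
modes_across_dp = torch.tensor([FULL, PIECEWISE, FULL, FULL], dtype=torch.int32)

# _post_process_cudagraph_mode takes the minimum, so a single rank reporting
# PIECEWISE (or NONE) pulls every rank down to that mode.
synced_cudagraph_mode = int(modes_across_dp.min().item())
assert synced_cudagraph_mode == PIECEWISE

# The re-dispatch then disables full graphs whenever the synced mode is
# PIECEWISE or lower, keeping graph and padding decisions consistent across
# all DP ranks.
disable_full = synced_cudagraph_mode <= PIECEWISE
assert disable_full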
Expand Down