58 changes: 44 additions & 14 deletions vllm/config/compilation.py
@@ -18,6 +18,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_torch_equal_or_newer

if TYPE_CHECKING:
@@ -705,6 +706,9 @@ def __post_init__(self) -> None:
if self.backend == "":
self.backend = current_platform.simple_compile_backend

# Recomputed in the model runner, but compute it here as well for testing.
self.post_init_cudagraph_sizes()

def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
Initialize the backend for the compilation config from a vllm config.
@@ -773,20 +777,6 @@ def post_init_cudagraph_sizes(self) -> None:
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end

def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
@@ -922,3 +912,43 @@ def custom_op_log_check(self):
enable_str,
op,
)

def adjust_cudagraph_sizes_to_be_multiple_of(self, multiple_of: int):
if not self.cudagraph_capture_sizes or multiple_of <= 1:
return

assert self.max_cudagraph_capture_size is not None
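# Round each capture size up to the requested multiple, dedupe, and
# drop any size that would exceed max_cudagraph_capture_size.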
rounded_sizes = sorted(
set(
round_up(size, multiple_of)
for size in self.cudagraph_capture_sizes
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
)
)

if len(rounded_sizes) == 0:
logger.warning(
"No valid cudagraph sizes after rounding to multiple of "
" num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
" or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
multiple_of,
)
return

self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes

def compute_bs_to_padded_graph_size(self):
# pre-compute the mapping from batch size to padded graph size
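# (a batch size that is itself a capture size maps to itself; anything
# in between pads up to the next larger capture size)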
self.bs_to_padded_graph_size = [0] * (self.max_cudagraph_capture_size + 1)
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
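
For reference, here is a standalone sketch of what the two new helpers compute. The capture sizes, max size, and multiple below are made-up values for illustration, and `round_up` is re-implemented locally to mirror `vllm.utils.math_utils.round_up`:

```python
# Standalone illustration of the two helpers above; all inputs are
# hypothetical, and round_up mirrors vllm.utils.math_utils.round_up.
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (for positive ints).
    return ((x + multiple - 1) // multiple) * multiple

capture_sizes = [1, 2, 4, 8, 16]  # assumed cudagraph_capture_sizes (ascending)
max_size = 16                     # assumed max_cudagraph_capture_size
multiple_of = 3                   # e.g. uniform_decode_query_len with 2 spec tokens

# adjust_cudagraph_sizes_to_be_multiple_of: round up, dedupe, drop > max.
rounded = sorted(
    {round_up(s, multiple_of) for s in capture_sizes
     if round_up(s, multiple_of) <= max_size}
)
print(rounded)  # [3, 6, 9]  (16 rounds up to 18, which exceeds max_size)

# compute_bs_to_padded_graph_size: a batch size that is itself a capture
# size maps to itself; anything in between pads up to the next one.
max_size = rounded[-1]
bs_to_padded = [0] * (max_size + 1)
for end, start in zip(rounded + [max_size + 1], [0] + rounded):
    for bs in range(start, end):
        bs_to_padded[bs] = start if bs == start else end
print(bs_to_padded)  # [0, 3, 3, 3, 6, 6, 6, 9, 9, 9]
```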
18 changes: 18 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -4332,6 +4332,24 @@ def _check_and_update_cudagraph_mode(
"and make sure compilation mode is VLLM_COMPILE"
)

# If we have dedicated decode cudagraphs and spec-decode is enabled,
# we need to adjust the cudagraph sizes to be a multiple of the uniform
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
# Will be removed in the near future when we have separate cudagraph capture
# sizes for decode and mixed prefill-decode.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
and self.uniform_decode_query_len > 1
):
self.compilation_config.adjust_cudagraph_sizes_to_be_multiple_of(
self.uniform_decode_query_len
)
self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes

self.compilation_config.compute_bs_to_padded_graph_size()

# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self.cudagraph_dispatcher.initialize_cudagraph_keys(
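
To make the motivation concrete, a small illustration with assumed values (not taken from this PR): with spec-decode, every request in a uniform decode batch contributes `uniform_decode_query_len` tokens, so the token count of such a batch is always a multiple of that length.

```python
# Illustration only; num_speculative_tokens is an assumed example value.
num_speculative_tokens = 2
uniform_decode_query_len = 1 + num_speculative_tokens  # 3 tokens per request

# Token counts of uniform decode batches are always multiples of 3 ...
decode_batch_token_counts = [n * uniform_decode_query_len for n in (1, 2, 4, 8)]
print(decode_batch_token_counts)  # [3, 6, 12, 24]

# ... so a full-decode cudagraph captured at a non-multiple size (e.g. 8)
# can never be hit exactly by a uniform decode batch, which is what the
# rounding in adjust_cudagraph_sizes_to_be_multiple_of prevents
# (see issue #28207 linked above).
print(8 % uniform_decode_query_len == 0)  # False
```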