2 changes: 2 additions & 0 deletions vllm/attention/backends/registry.py
@@ -28,6 +28,7 @@ class _Backend(enum.Enum):
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    ROCM_AITER_UNIFIED_ATTN = enum.auto()
    MIRAGE = enum.auto()


BACKEND_MAP = {
@@ -49,6 +50,7 @@
    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
    _Backend.MIRAGE: "vllm.v1.attention.backends.mirage.MirageAttentionBackend",  # noqa: E501
}
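For reference, entries in BACKEND_MAP are plain qualnames resolved lazily at selection time; a minimal sketch of that resolution (names here are illustrative, vLLM's own helper for this is resolve_obj_by_qualname):

import importlib

def resolve_backend(qualname: str):
    # "pkg.module.ClassName" -> the imported ClassName object
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve_backend(BACKEND_MAP[_Backend.MIRAGE])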


298 changes: 298 additions & 0 deletions vllm/compilation/mirage_backend.py
@@ -0,0 +1,298 @@
import os
import re
import time
from collections import defaultdict
from typing import Any

import torch
from torch import fx

from mirage import MPK, MPKMetadata, MirageModelConfig

from vllm.config import CompilationConfig, ModelConfig, VllmConfig
from vllm.config.parallel import ParallelConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index

# Keeps parity with vllm.compilation.backends and supplies helpers used
# below (model_tag, compilation_counter, PostGradPassManager, ...).
from .backends import *  # noqa: F401,F403

logger = init_logger(__name__)

def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
    """Translate FX placeholder debug names to model-like dotted names.

    Example:
        l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_
        -> model.layers.17.mlp.gate_up_proj.weight

    Notes:
        - Tailored for Qwen3-style module names seen in exported FX graphs.
        - We do NOT rename the FX node identifiers (dots are not valid in FX
          names); the translated names are returned in placeholder order.
    """
    converted_names = []
    s_pattern = re.compile(r"^s\d+$")  # symbolic shape placeholders

    for node in placeholders:
        name = node.name
        if name == 'l_input_ids_':
            converted_names.append('input_ids')
        elif name == 'l_positions_':
            converted_names.append('positions')
        elif s_pattern.match(name):  # s72 / s80
            converted_names.append(name)
        else:
            if name.startswith('l_self_modules_'):
                name = name.replace('l_self_modules_', '', 1)
            if name.endswith('_'):
                name = name[:-1]

            name = name.replace('_modules_', '.')
            name = name.replace('_parameters_', '.')

            converted_names.append('model.' + name)

    return converted_names
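A quick sanity check of the mapping (a sketch; SimpleNamespace stands in for FX nodes, since only `.name` is read):

from types import SimpleNamespace

fake_nodes = [
    SimpleNamespace(name='l_input_ids_'),
    SimpleNamespace(name='s72'),
    SimpleNamespace(
        name='l_self_modules_layers_modules_17_modules_mlp_modules_'
             'gate_up_proj_parameters_weight_'),
]
assert transfer_tensor_names(fake_nodes) == [
    'input_ids',
    's72',
    'model.layers.17.mlp.gate_up_proj.weight',
]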

def build_model_config(
    model_config: ModelConfig,
    state_dict: dict[str, torch.Tensor],
    k_cache_tensors: list[torch.Tensor],
    v_cache_tensors: list[torch.Tensor],
    position_embeddings_: torch.Tensor,
    parallel_config: ParallelConfig,
) -> MirageModelConfig:
    # The rotary cache stores cos in the first half of the last dim and sin
    # in the second half; duplicate each half to cover the full head_dim.
    whole_dim = position_embeddings_.shape[-1]
    cos_tensor_ = position_embeddings_[:, 0:whole_dim // 2].unsqueeze(0)
    sin_tensor_ = position_embeddings_[:, whole_dim // 2:].unsqueeze(0)

    cos_tensor = torch.cat([cos_tensor_, cos_tensor_], dim=-1)
    sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)

    position_embeddings = (cos_tensor, sin_tensor)
    logger.info("[Mirage] position_embeddings: %s, %s",
                position_embeddings[0].shape, position_embeddings[1].shape)
    mirage_model_config = MirageModelConfig(
        # model architecture
        hidden_size=model_config.get_hidden_size(),
        intermediate_size=getattr(model_config.hf_text_config,
                                  "intermediate_size", 0),
        vocab_size=model_config.get_vocab_size(),
        local_num_q_heads=model_config.get_num_attention_heads(parallel_config),
        local_num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_dim=model_config.get_head_size(),
        num_layers=getattr(model_config.hf_text_config, "num_hidden_layers", 0),
        # kv cache
        k_cache=k_cache_tensors,
        v_cache=v_cache_tensors,
        # position embeddings
        position_embeddings=position_embeddings,
        # model weights
        state_dict=state_dict,
        with_lm_head=False,
    )
    return mirage_model_config
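A minimal shape sketch (not from the PR; sizes illustrative) of the cos/sin handling above, assuming the rotary cache is laid out as (max_position, head_dim) with the cos and sin halves side by side:

import torch

cos_sin_cache = torch.randn(4096, 128)           # (max_position, head_dim)
cos_half = cos_sin_cache[:, :64].unsqueeze(0)    # (1, 4096, 64)
sin_half = cos_sin_cache[:, 64:].unsqueeze(0)
cos = torch.cat([cos_half, cos_half], dim=-1)    # duplicated to full head_dim
sin = torch.cat([sin_half, sin_half], dim=-1)
assert cos.shape == sin.shape == (1, 4096, 128)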

def build_mpk_metadata(
    vllm_config: VllmConfig,
    args: list[Any],
    transferred_tensor_names: list[str],
) -> MPKMetadata:
    forward_context = get_forward_context()
    model_config = vllm_config.model_config
    scheduler_config = vllm_config.scheduler_config
    cache_config = vllm_config.cache_config
    parallel_config = vllm_config.parallel_config
    # For now we assume only one attention group.
    attn_metadata = list(forward_context.attn_metadata.values())[0]

    static_forward_context = forward_context.no_compile_layers  # layer name -> layer
    k_cache_tensors = []
    v_cache_tensors = []
    # Convert the kv_caches dict to lists of tensors ordered by layer index.
    index2name = defaultdict(list)
    for layer_name in static_forward_context:
        index2name[extract_layer_index(layer_name, 1)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        assert len(layer_names) == 1, (
            "Multiple layers with the same layer index are not supported")
        layer_name = layer_names[0]
        k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0])
        v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1])
    # kv cache shape:
    # num_layers * (2, num_blocks, block_size, num_kv_heads, head_size)

    state_dict = {}
    input_token_ids = None
    positions_tensor = None
    position_embeddings = None
    for arg, name in zip(args, transferred_tensor_names):
        if name == 'input_ids':
            input_token_ids = arg
        elif name == 'positions':
            positions_tensor = arg
        elif "cos_sin_cache" in name:
            position_embeddings = arg
        else:
            state_dict[name] = arg

    mirage_model_config = build_model_config(
        model_config,
        state_dict,
        k_cache_tensors,
        v_cache_tensors,
        position_embeddings,
        parallel_config,
    )
    mpk_metadata = MPKMetadata(
        mode="online",
        # total_num_requests
        # num_remote_schedulers: int = 0
        max_seq_length=model_config.max_model_len,
        max_num_batched_requests=scheduler_config.max_num_seqs,
        max_num_batched_tokens=scheduler_config.max_num_batched_tokens,
        max_num_pages=cache_config.num_gpu_blocks,
        page_size=cache_config.block_size,
        # max_sm_num: int = 108
        device="cuda",
        # model
        weight_from_model=False,
        model_name=model_config.model,
        # model_path: Optional[str] = None
        # multi-device support
        world_size=parallel_config.world_size,
        rank=parallel_config.rank,
        # meta tensors
        step=positions_tensor,
        # tokens: Optional[torch.Tensor] = None
        input_tokens=input_token_ids,
        # output_tokens: Optional[torch.Tensor] = None
        # num_new_tokens: Optional[torch.Tensor] = None
        # prompt_lengths: Optional[torch.Tensor] = None
        qo_indptr_buffer=attn_metadata.qo_indptr_gpu,
        paged_kv_indptr_buffer=attn_metadata.paged_kv_indptr_gpu,
        paged_kv_indices_buffer=attn_metadata.paged_kv_indices_gpu,
        paged_kv_last_page_len_buffer=attn_metadata.paged_kv_last_page_len_gpu,
        # kv cache tensors, weights and model config
        model_config=mirage_model_config,
        # profiling
        # profiler_tensor: Optional[torch.Tensor] = None
        # trace_name: Optional[str] = None
        # spec decode config
        # spec_decode_config: Optional[object] = None
    )
    return mpk_metadata
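For orientation, an illustrative sketch (not from the PR) of the FlashInfer-style paged-KV layout the qo_indptr/paged_kv_* buffers above encode, assuming vLLM's usual convention, for two requests of 3 and 5 tokens with page_size=4:

import torch

# request 0: 3 tokens -> 1 page; request 1: 5 tokens -> 2 pages (4 + 1)
qo_indptr = torch.tensor([0, 3, 8])             # query-token offsets per request
paged_kv_indptr = torch.tensor([0, 1, 3])       # offsets into paged_kv_indices
paged_kv_indices = torch.tensor([7, 2, 9])      # physical page ids in the KV pool
paged_kv_last_page_len = torch.tensor([3, 1])   # valid slots in each last page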

class MirageBackend:
    """A compilation backend for `torch.compile` with vLLM that lowers the
    whole FX graph to Mirage's MPK.

    Unlike `VllmBackend`, this backend performs no graph splitting: the
    graph it receives is compiled as a whole, and compilation is deferred
    to the first call that arrives with real tensors (so KV-cache tensors
    and attention metadata are available).

    A `PostGradPassManager` is instantiated for post-grad passes.
    """

Comment on lines +187 to +189 (Contributor, high):
The docstring for MirageBackend states that its major work is to split the graph into piecewise graphs. However, the implementation does not perform any graph splitting; it compiles the graph it receives as a whole. This docstring appears to be copied from VllmBackend and is misleading. Please update it to accurately describe what MirageBackend does.

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule

    input_buffers: list[torch.Tensor]

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # If the model is initialized with a non-empty prefix, then usually
        # it's enough to use that prefix, e.g. language_model, vision_model,
        # etc. When multiple parts are initialized as independent models, we
        # need the model_tag to distinguish them, e.g. backbone (default),
        # eagle_head, etc.
        logger.info("[Mirage] Calling MirageBackend init!")
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.model_config = vllm_config.model_config
        self.model_name = vllm_config.model_config.model


    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> Any:
        # When dynamo calls the backend, the bytecode transform and
        # analysis are done.
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        # We control the compilation process; each instance can only be
        # called once.
        assert not self._called, "MirageBackend can only be called once"

        placeholders = [
            node for node in graph.graph.nodes if node.op == 'placeholder'
        ]
        assert len(placeholders) == len(example_inputs)

        transferred_tensor_names = transfer_tensor_names(placeholders)

        self._called = True
        self.compiled = False
        def compile_or_call(*args):
            dummy_run = get_forward_context().attn_metadata is None
            if dummy_run:
                # Profiling/dummy runs carry no attention metadata; return a
                # placeholder hidden-states tensor without invoking Mirage.
                model_config = self.vllm_config.model_config
                dtype = model_config.dtype
                hidden_size = model_config.get_hidden_size()
                output_tensor = torch.zeros(
                    1, hidden_size, device='cuda', dtype=dtype)
                logger.info(
                    "[Mirage] Dummy run detected, returning dummy output "
                    "tensor with shape [%s]", output_tensor.shape)
                return (output_tensor,)

            if not self.compiled:
                # Compile only at the first call -- when we get real tensors.
                logger.info(
                    "[Mirage] Calling compile_or_call for the first time, "
                    "compiling...")
                mpk_metadata = build_mpk_metadata(
                    self.vllm_config,
                    args,
                    transferred_tensor_names,
                )
                logger.info("[Mirage] MPK metadata: %s",
                            mpk_metadata.info_as_string())
                self.mpk = MPK(mpk_metadata)
                self.mpk.build()
                self.mpk.compile(output_dir=os.path.join(
                    os.path.dirname(__file__), "mirage_backend_output"))

                self.compiled = True

            logger.info("[Mirage] Calling the compiled result...")
            return self.mpk()
Comment (Collaborator):
Shouldn't we try to pass at least the input tensors to mpk here? Or do those have to be static too? If yes, I'm not sure that's guaranteed by the GPUModelRunner
        # return VllmSerializableFunction(
        #     graph, example_inputs, self.prefix, compile_or_call
        # )
        return compile_or_call
3 changes: 3 additions & 0 deletions vllm/config/compilation.py
@@ -652,6 +652,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
        ]:
            if self.backend in torch_backends:
                return self.backend
            if self.backend == "mirage":
                from vllm.compilation.mirage_backend import MirageBackend
                return MirageBackend(vllm_config)
            return resolve_obj_by_qualname(self.backend)
Comment (Collaborator), suggested change:
-            return resolve_obj_by_qualname(self.backend)
+            if self.backend == "mirage":
+                from vllm.compilation.mirage_backend import MirageBackend
+                return MirageBackend(vllm_config)
+            return resolve_obj_by_qualname(self.backend)


        assert self.mode == CompilationMode.VLLM_COMPILE
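Putting the pieces together, a hedged usage sketch (not part of the PR; model name illustrative) of how the new hooks would be exercised, assuming the backend lands as shown in the diffs above and below:

import os

os.environ["VLLM_ATTENTION_BACKEND"] = "MIRAGE"  # select the attention backend

from vllm import LLM
from vllm.config import CompilationConfig

# backend="mirage" makes init_backend() return MirageBackend, per the diff above
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    compilation_config=CompilationConfig(backend="mirage"),
)
out = llm.generate("Hello")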
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1659,6 +1659,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"XFORMERS",
"ROCM_ATTN",
"ROCM_AITER_UNIFIED_ATTN",
"MIRAGE"
]
        if (
            envs.is_set("VLLM_ATTENTION_BACKEND")
6 changes: 5 additions & 1 deletion vllm/platforms/cuda.py
@@ -338,7 +338,8 @@ def get_attn_backend_cls(
        )
        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
        MIRAGE_V1 = "vllm.v1.attention.backends.mirage.MirageAttentionBackend"  # noqa: E501

        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
            "fp8"
        )
@@ -365,6 +366,9 @@
        elif selected_backend == _Backend.XFORMERS:
            logger.info_once("Using XFormers backend on V1 engine.")
            return XFORMERS_V1
        elif selected_backend == _Backend.MIRAGE:
            logger.info_once("Using Mirage backend on V1 engine.")
            return MIRAGE_V1

        from vllm.attention.selector import is_attn_backend_supported
