14 changes: 14 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -142,6 +142,17 @@ def supports_sink(cls) -> bool:
def is_sparse(cls) -> bool:
return False

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType

return attn_type == AttentionType.DECODER

@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@@ -171,6 +182,7 @@ def validate_configuration(
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
@@ -195,6 +207,8 @@
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
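For illustration, a hypothetical out-of-tree backend that also serves encoder-only layers would override the new hook roughly as below. The class name is made up and the backend's other abstract methods are omitted; only the override added by this PR is shown.

```python
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend


class EncoderOnlyCapableBackend(AttentionBackend):
    """Hypothetical backend, shown only to illustrate the new hook."""

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Accept decoder and encoder-only layers; any other attention type is
        # reported by validate_configuration() as
        # "attention type <attn_type> not supported".
        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
```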
1 change: 1 addition & 0 deletions vllm/attention/layer.py
@@ -291,6 +291,7 @@ def __init__(
block_size,
use_mla=False,
has_sink=self.has_sink,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
6 changes: 5 additions & 1 deletion vllm/attention/layers/encoder_only_attention.py
@@ -74,7 +74,11 @@ def __init__(
block_size = 16

underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)

attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
5 changes: 5 additions & 0 deletions vllm/attention/selector.py
@@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""

@@ -94,6 +95,7 @@
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
attn_type=attn_type,
)


@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
@@ -170,6 +174,7 @@
use_mla,
has_sink,
use_sparse,
attn_type,
)
if not attention_cls:
raise ValueError(
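A hypothetical usage sketch of the extended selector API: after this change, a caller can hint the layer's attention type when resolving a backend. The argument values below are illustrative, not taken from the PR.

```python
import torch

from vllm.attention import AttentionType
from vllm.attention.selector import get_attn_backend

# Resolve a backend for an encoder-only layer (illustrative values).
backend_cls = get_attn_backend(
    head_size=128,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    block_size=16,
    attn_type=AttentionType.ENCODER_ONLY,  # new keyword threaded by this PR
)
```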
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
@@ -134,6 +134,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

10 changes: 10 additions & 0 deletions vllm/platforms/cuda.py
@@ -298,6 +298,7 @@ def get_valid_backends(
has_sink,
use_sparse,
device_capability,
attn_type,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
@@ -318,6 +319,7 @@
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
@@ -339,7 +341,13 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention import AttentionType

if attn_type is None:
attn_type = AttentionType.DECODER

device_capability = cls.get_device_capability()
assert device_capability is not None

@@ -356,6 +364,7 @@
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons = ["ImportError"]
@@ -379,6 +388,7 @@
has_sink,
use_sparse,
device_capability,
attn_type,
)
reasons_str = (
"{"
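A simplified, illustrative sketch of how the new check participates in backend filtering; it mirrors the shape of `get_valid_backends` above but is not the PR's code, and the function name is made up. The real path also checks head size, dtype, block size, sinks, sparsity, and compute capability.

```python
# Illustrative only: reduce backend filtering to the attn_type check added here.
def filter_backends_by_attn_type(candidates, attn_type: str):
    valid, invalid_reasons = [], {}
    for backend in candidates:
        if backend.supports_attn_type(attn_type):
            valid.append(backend)
        else:
            # Matches the reason string appended by validate_configuration().
            invalid_reasons[backend] = [
                f"attention type {attn_type} not supported"
            ]
    return valid, invalid_reasons
```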
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -222,6 +222,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
return ""
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
@@ -216,6 +216,7 @@ def get_attn_backend_cls(
use_mla,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
1 change: 1 addition & 0 deletions vllm/platforms/tpu.py
@@ -61,6 +61,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
@@ -51,6 +51,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout

11 changes: 11 additions & 0 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -48,6 +48,17 @@ def get_supported_head_sizes(cls) -> list[int]:
def get_name() -> str:
return "CPU_ATTN"

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention import AttentionType

return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
)

@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl
12 changes: 12 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention import AttentionType

return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)

@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
7 changes: 7 additions & 0 deletions vllm/v1/attention/backends/flex_attention.py
@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLEX_ATTENTION"

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention import AttentionType

return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)

@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -40,14 +40,14 @@
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format

In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
