
Commit fe7f580

comments
Signed-off-by: Huamin Li <[email protected]>
1 parent 64d57c3 commit fe7f580

14 files changed: +68 -6 lines changed

vllm/attention/backends/abstract.py

Lines changed: 14 additions & 0 deletions
@@ -142,6 +142,17 @@ def supports_sink(cls) -> bool:
     def is_sparse(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """Check if backend supports a given attention type.
+
+        By default, only supports decoder attention.
+        Backends should override this to support other attention types.
+        """
+        from vllm.attention import AttentionType
+
+        return attn_type == AttentionType.DECODER
+
     @classmethod
     def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
         return True

@@ -171,6 +182,7 @@ def validate_configuration(
         has_sink: bool,
         use_sparse: bool,
         device_capability: "DeviceCapability",
+        attn_type: str,
     ) -> list[str]:
         invalid_reasons = []
         if not cls.supports_head_size(head_size):

@@ -195,6 +207,8 @@ def validate_configuration(
             invalid_reasons.append("non-sparse not supported")
         if not cls.supports_compute_capability(device_capability):
             invalid_reasons.append("compute capability not supported")
+        if not cls.supports_attn_type(attn_type):
+            invalid_reasons.append(f"attention type {attn_type} not supported")
         combination_reason = cls.supports_combination(
             head_size,
             dtype,
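
The new classmethod gives every backend a single hook for declaring which attention types it can serve, and validate_configuration turns a mismatch into a human-readable rejection reason. A rough sketch of how a backend would opt in to more types (the subclass below is illustrative, not something added by this commit):

# Illustrative only: "MyEncoderCapableBackend" is a made-up subclass used to
# show the override; it is not part of this commit.
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend


class MyEncoderCapableBackend(AttentionBackend):
    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Accept decoder and encoder-only attention; anything else is rejected.
        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)

Backends that do not override the hook keep the decoder-only default and are filtered out whenever a caller asks for any other attention type.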

vllm/attention/layer.py

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ def __init__(
                 block_size,
                 use_mla=False,
                 has_sink=self.has_sink,
+                attn_type=attn_type,
             )
         else:
             self.attn_backend = attn_backend

vllm/attention/layers/encoder_only_attention.py

Lines changed: 5 additions & 1 deletion
@@ -74,7 +74,11 @@ def __init__(
             block_size = 16
 
         underlying_attn_backend = get_attn_backend(
-            head_size, dtype, kv_cache_dtype, block_size
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            attn_type=AttentionType.ENCODER_ONLY,
         )
 
         attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
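
With the new keyword, the encoder-only wrapper asks the selector up front for a backend that can actually serve ENCODER_ONLY attention instead of relying on the decoder default. A minimal sketch of such a call (the concrete head size, dtype, and block size are placeholders):

import torch

from vllm.attention import AttentionType
from vllm.attention.selector import get_attn_backend

# Placeholder arguments; attn_type is the only part this sketch is about.
backend_cls = get_attn_backend(
    64,              # head_size
    torch.float16,   # dtype
    "auto",          # kv_cache_dtype
    16,              # block_size
    attn_type=AttentionType.ENCODER_ONLY,
)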

vllm/attention/selector.py

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,7 @@ def get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
 

@@ -94,6 +95,7 @@
         use_mla=use_mla,
         has_sink=has_sink,
         use_sparse=use_sparse,
+        attn_type=attn_type,
     )
 
 

@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     # Check whether a particular choice of backend was
     # previously forced.
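
Because attn_type defaults to None and the platform hooks fall back to decoder attention, existing callers keep their current behavior; assuming _cached_get_attn_backend memoizes on its arguments (as its name suggests), requests for different attention types are also cached separately. A rough sketch with placeholder arguments:

import torch

from vllm.attention import AttentionType
from vllm.attention.selector import get_attn_backend

# Both calls should resolve to the same backend class on a given platform,
# since a missing attn_type is treated as decoder attention downstream.
default_backend = get_attn_backend(64, torch.float16, "auto", 16)
decoder_backend = get_attn_backend(
    64, torch.float16, "auto", 16, attn_type=AttentionType.DECODER
)
assert default_backend is decoder_backend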

vllm/platforms/cpu.py

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
vllm/platforms/cuda.py

Lines changed: 10 additions & 0 deletions
@@ -298,6 +298,7 @@ get_valid_backends(
         has_sink,
         use_sparse,
         device_capability,
+        attn_type,
     ) -> tuple[
         list[tuple["AttentionBackendEnum", int]],
         dict["AttentionBackendEnum", list[str]],

@@ -318,6 +319,7 @@
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
         except ImportError:
             invalid_reasons_i = ["ImportError"]

@@ -339,7 +341,13 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
+        from vllm.attention import AttentionType
+
+        if attn_type is None:
+            attn_type = AttentionType.DECODER
+
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 

@@ -356,6 +364,7 @@
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
         except ImportError:
             invalid_reasons = ["ImportError"]

@@ -379,6 +388,7 @@
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
             reasons_str = (
                 "{"

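On the CUDA path, a missing attn_type is normalized to AttentionType.DECODER before any backend is validated, and backends whose supports_attn_type rejects the requested type are dropped with an "attention type ... not supported" reason. A minimal, self-contained sketch of that filtering step (the candidate list is a placeholder, not the real backend registry):

from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend

candidate_backends: list[type[AttentionBackend]] = []  # placeholder registry
attn_type: str | None = None
if attn_type is None:
    attn_type = AttentionType.DECODER  # same default applied in cuda.py

valid, invalid_reasons = [], {}
for backend_cls in candidate_backends:
    if backend_cls.supports_attn_type(attn_type):
        valid.append(backend_cls)
    else:
        invalid_reasons[backend_cls] = [f"attention type {attn_type} not supported"]
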
vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         """Get the attention backend class of a device."""
         return ""

vllm/platforms/rocm.py

Lines changed: 1 addition & 0 deletions
@@ -216,6 +216,7 @@ def get_attn_backend_cls(
         use_mla,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm._aiter_ops import rocm_aiter_ops
         from vllm.attention.backends.registry import AttentionBackendEnum

vllm/platforms/tpu.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 

vllm/platforms/xpu.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
