Commit 99eff67

[Bugfix][Kernel] Add head size check for attention backend selection (vllm-project#4944)
1 parent 14772ee commit 99eff67

File tree

2 files changed: +21 -7 lines changed

vllm/attention/backends/flash_attn.py (+8 -4)

@@ -9,11 +9,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
@@ -237,10 +239,12 @@ def __init__(
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
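
For illustration, the sketch below reproduces the validation logic above as a standalone snippet (no vLLM import; the validate_head_size helper is hypothetical and exists only for this example). It shows the error raised for a head size outside the supported list, such as 80.

# Minimal standalone sketch of the head-size check (does not import vLLM).
from typing import List


def get_supported_head_sizes() -> List[int]:
    # Mirrors FlashAttentionBackend.get_supported_head_sizes() from the diff.
    return [32, 64, 96, 128, 160, 192, 224, 256]


def validate_head_size(head_size: int) -> None:
    # Hypothetical helper reproducing the __init__ check from the diff.
    support_head_sizes = get_supported_head_sizes()
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}.")


validate_head_size(128)  # OK: 128 is in the supported list.
try:
    validate_head_size(80)  # Raises: 80 is not a supported head size.
except ValueError as e:
    print(e)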

vllm/attention/selector.py (+13 -3)

@@ -34,11 +34,21 @@ def get_attn_backend(
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
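
Taken together with the backend change above, the selector now degrades gracefully: if FlashAttention was chosen but the model's head size is unsupported, it logs a message and returns the XFormers backend instead of failing later. The following standalone sketch mirrors that control flow; select_backend, the simplified _Backend enum, and the hard-coded head-size list are stand-ins for this example only, not the real get_attn_backend signature.

# Standalone sketch of the fallback logic (names simplified; the real
# get_attn_backend also considers dtype, kv-cache dtype, sliding window, etc.).
import enum
import logging

logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


FLASH_ATTN_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


def select_backend(head_size: int,
                   backend: _Backend = _Backend.FLASH_ATTN) -> _Backend:
    # If FlashAttention was chosen but the head size is unsupported,
    # fall back to XFormers, as in the diff above.
    if backend == _Backend.FLASH_ATTN:
        if head_size in FLASH_ATTN_HEAD_SIZES:
            logger.info("Using FlashAttention-2 backend.")
            return _Backend.FLASH_ATTN
        logger.info(
            "Cannot use FlashAttention-2 backend for head size %d. "
            "Using XFormers backend instead.", head_size)
        backend = _Backend.XFORMERS
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        return _Backend.XFORMERS
    raise ValueError(f"Unsupported backend: {backend}")


assert select_backend(128) is _Backend.FLASH_ATTN
assert select_backend(80) is _Backend.XFORMERS  # e.g. head size 80 falls back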
