Skip to content

Commit 2f925e5

Browse files
[Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)
Signed-off-by: Thomas Parnell <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]>
1 parent de906b9 commit 2f925e5

File tree

3 files changed

+566
-27
lines changed

3 files changed

+566
-27
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from typing import Optional
4+
5+
import pytest
6+
import torch
7+
8+
from vllm.attention.ops.triton_unified_attention import unified_attention
9+
from vllm.platforms import current_platform
10+
11+
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
12+
HEAD_SIZES = [128, 256]
13+
BLOCK_SIZES = [16, 32]
14+
15+
DTYPES = [torch.float16, torch.bfloat16]
16+
QDTYPES = [None, torch.float8_e4m3fn]
17+
# one value large enough to test overflow in index calculation.
18+
# one value small enough to test the schema op check
19+
NUM_BLOCKS = [32768, 2048]
20+
21+
22+
def ref_paged_attn(
23+
query: torch.Tensor,
24+
key_cache: torch.Tensor,
25+
value_cache: torch.Tensor,
26+
query_lens: list[int],
27+
kv_lens: list[int],
28+
block_tables: torch.Tensor,
29+
scale: float,
30+
sliding_window: Optional[int] = None,
31+
soft_cap: Optional[float] = None,
32+
) -> torch.Tensor:
33+
num_seqs = len(query_lens)
34+
block_tables = block_tables.cpu().numpy()
35+
_, block_size, num_kv_heads, head_size = key_cache.shape
36+
37+
outputs: list[torch.Tensor] = []
38+
start_idx = 0
39+
for i in range(num_seqs):
40+
query_len = query_lens[i]
41+
kv_len = kv_lens[i]
42+
q = query[start_idx:start_idx + query_len]
43+
q *= scale
44+
45+
num_kv_blocks = (kv_len + block_size - 1) // block_size
46+
block_indices = block_tables[i, :num_kv_blocks]
47+
48+
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
49+
k = k[:kv_len]
50+
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
51+
v = v[:kv_len]
52+
53+
if q.shape[1] != k.shape[1]:
54+
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
55+
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
56+
attn = torch.einsum("qhd,khd->hqk", q, k).float()
57+
empty_mask = torch.ones(query_len, kv_len)
58+
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
59+
if sliding_window is not None:
60+
sliding_window_mask = torch.triu(empty_mask,
61+
diagonal=kv_len -
62+
(query_len + sliding_window) +
63+
1).bool().logical_not()
64+
mask |= sliding_window_mask
65+
if soft_cap is not None and soft_cap > 0:
66+
attn = soft_cap * torch.tanh(attn / soft_cap)
67+
attn.masked_fill_(mask, float("-inf"))
68+
attn = torch.softmax(attn, dim=-1).to(v.dtype)
69+
out = torch.einsum("hqk,khd->qhd", attn, v)
70+
71+
outputs.append(out)
72+
start_idx += query_len
73+
74+
return torch.cat(outputs, dim=0)
75+
76+
77+
@pytest.mark.parametrize("seq_lens",
78+
[[(1, 1328), (5, 18),
79+
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
80+
@pytest.mark.parametrize("num_heads", NUM_HEADS)
81+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
82+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
83+
@pytest.mark.parametrize("sliding_window", [None, 256])
84+
@pytest.mark.parametrize("dtype", DTYPES)
85+
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
86+
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
87+
@pytest.mark.parametrize("q_dtype", QDTYPES)
88+
@torch.inference_mode()
89+
def test_triton_unified_attn(
90+
seq_lens: list[tuple[int, int]],
91+
num_heads: tuple[int, int],
92+
head_size: int,
93+
sliding_window: Optional[int],
94+
dtype: torch.dtype,
95+
block_size: int,
96+
soft_cap: Optional[float],
97+
num_blocks: int,
98+
q_dtype: Optional[torch.dtype],
99+
) -> None:
100+
torch.set_default_device("cuda")
101+
102+
current_platform.seed_everything(0)
103+
num_seqs = len(seq_lens)
104+
query_lens = [x[0] for x in seq_lens]
105+
kv_lens = [x[1] for x in seq_lens]
106+
num_query_heads = num_heads[0]
107+
num_kv_heads = num_heads[1]
108+
assert num_query_heads % num_kv_heads == 0
109+
max_query_len = max(query_lens)
110+
max_kv_len = max(kv_lens)
111+
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
112+
(-1, -1))
113+
scale = head_size**-0.5
114+
115+
query = torch.randn(sum(query_lens),
116+
num_query_heads,
117+
head_size,
118+
dtype=dtype)
119+
key_cache = torch.randn(num_blocks,
120+
block_size,
121+
num_kv_heads,
122+
head_size,
123+
dtype=dtype)
124+
value_cache = torch.randn_like(key_cache)
125+
cu_query_lens = torch.tensor([0] + query_lens,
126+
dtype=torch.int32).cumsum(dim=0,
127+
dtype=torch.int32)
128+
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
129+
130+
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
131+
block_tables = torch.randint(0,
132+
num_blocks,
133+
(num_seqs, max_num_blocks_per_seq),
134+
dtype=torch.int32)
135+
136+
output = torch.empty_like(query)
137+
138+
maybe_quantized_query = query
139+
maybe_quantized_key_cache = key_cache
140+
maybe_quantized_value_cache = value_cache
141+
q_descale = None
142+
k_descale = None
143+
v_descale = None
144+
if q_dtype is not None:
145+
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
146+
maybe_quantized_query = query.to(q_dtype)
147+
maybe_quantized_key_cache = key_cache.to(q_dtype)
148+
maybe_quantized_value_cache = value_cache.to(q_dtype)
149+
150+
scale_shape = (num_seqs, num_kv_heads)
151+
q_descale = None # Not yet supported
152+
k_descale = torch.rand(scale_shape, dtype=torch.float32)
153+
v_descale = torch.rand(scale_shape, dtype=torch.float32)
154+
155+
unified_attention(
156+
q=maybe_quantized_query,
157+
k=maybe_quantized_key_cache,
158+
v=maybe_quantized_value_cache,
159+
out=output,
160+
cu_seqlens_q=cu_query_lens,
161+
seqused_k=kv_lens,
162+
max_seqlen_q=max_query_len,
163+
max_seqlen_k=max_kv_len,
164+
softmax_scale=scale,
165+
causal=True,
166+
window_size=window_size,
167+
block_table=block_tables,
168+
softcap=soft_cap if soft_cap is not None else 0,
169+
q_descale=q_descale,
170+
k_descale=k_descale,
171+
v_descale=v_descale,
172+
)
173+
174+
ref_output = ref_paged_attn(
175+
query=query,
176+
key_cache=key_cache,
177+
value_cache=value_cache,
178+
query_lens=query_lens,
179+
kv_lens=kv_lens,
180+
block_tables=block_tables,
181+
scale=scale,
182+
sliding_window=sliding_window,
183+
soft_cap=soft_cap,
184+
)
185+
atol, rtol = 1.5e-2, 1e-2
186+
if q_dtype is not None:
187+
atol, rtol = 1.5e-1, 1.5e-1
188+
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
189+
f"{torch.max(torch.abs(output - ref_output))}"

0 commit comments

Comments
 (0)