chunked mla
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Feb 4, 2025
1 parent 18016a5 commit 77be9af
Showing 2 changed files with 94 additions and 7 deletions.
7 changes: 2 additions & 5 deletions vllm/attention/backends/mla/utils.py
@@ -437,10 +437,6 @@ def forward(
is_decode = attn_metadata.decode_metadata is not None
is_prefill = attn_metadata.prefill_metadata is not None

-if (is_decode and is_prefill):
-    raise NotImplementedError(
-        "chunked prefill is not supported for MLAImplBase")
-
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
@@ -474,7 +470,8 @@ def forward(
)

if attn_metadata.prefill_metadata is not None:
-return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
+return self._forward_prefill(q, k_c_normed, k_pe, kv_cache,
+                             attn_metadata)

if attn_metadata.decode_metadata is not None:
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
94 changes: 92 additions & 2 deletions vllm/attention/backends/triton_mla.py
@@ -6,6 +6,9 @@
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import triton
import triton.language as tl

from vllm.multimodal import MultiModalPlaceholderMap

try:
@@ -648,6 +651,63 @@ def build(self, seq_lens: List[int], query_lens: List[int],
)


@triton.jit
def _gather_kv_cache(
# Pointers to inputs and output
seq_start_locs, # (batch_size + 1,)
block_tables, # (batch_size, max_blocks_per_seq)
block_table_stride,
kv_cache, # (num_blocks, block_size, head_size)
kv_page_stride,
kv_out,
CACHE_PAGE_SIZE: tl.constexpr,
CACHE_ENTRY_SIZE: tl.constexpr,
CACHE_ENTRIES_PER_PAGE: tl.constexpr,
CACHE_PAGE_SIZE_POW_2: tl.constexpr,
CACHE_ENTRY_SIZE_POW_2: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""

# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)

seq_start_loc = tl.load(seq_start_locs + g_id)
seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc

pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE)
kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE
block_table = block_tables + g_id * block_table_stride

cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2)
cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
for i in range(pages_to_copy - 1):
page = tl.load(block_table + i)
page_start = kv_cache + page * kv_page_stride
page_data = tl.load(page_start + cache_page_range,
mask=cache_page_mask)
tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range,
page_data,
mask=cache_page_mask)

last_page_len = seq_len - (pages_to_copy - 1) * CACHE_ENTRIES_PER_PAGE
last_page = tl.load(block_table + pages_to_copy - 1)
last_page_start = kv_cache + last_page * kv_page_stride

cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2)
cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE
for i in range(last_page_len):
last_page_data = tl.load(last_page_start + \
i * CACHE_ENTRY_SIZE + cache_entry_range,
mask=cache_entry_mask)
tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range,
last_page_data,
mask=cache_entry_mask)
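
For reference, the gather this kernel performs can be written as a plain-PyTorch loop over sequences. The sketch below is illustrative only and not part of the commit; the function name is hypothetical and the (num_blocks, block_size, entry_size) cache layout is assumed from the kernel arguments above.

import torch

def gather_kv_cache_reference(seq_start_locs: torch.Tensor,
                              block_tables: torch.Tensor,
                              kv_cache: torch.Tensor) -> torch.Tensor:
    """Copy each sequence's paged cache entries into one contiguous buffer."""
    block_size, entry_size = kv_cache.shape[-2], kv_cache.shape[-1]
    total_entries = int(seq_start_locs[-1])
    kv_out = kv_cache.new_empty((total_entries, entry_size))
    for b in range(block_tables.shape[0]):
        start = int(seq_start_locs[b])
        seq_len = int(seq_start_locs[b + 1]) - start
        num_pages = (seq_len + block_size - 1) // block_size  # ceil division
        # Fetch this sequence's pages, flatten them to rows of entry_size,
        # then trim the partially filled last page down to seq_len rows.
        pages = kv_cache[block_tables[b, :num_pages].long()]
        kv_out[start:start + seq_len] = pages.reshape(-1, entry_size)[:seq_len]
    return kv_out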


class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):

def __init__(
@@ -687,12 +747,42 @@ def __init__(
def _forward_prefill(
self,
q: torch.Tensor,
-kv_c_normed: torch.Tensor,
+kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: TritonMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, TritonMLAMetadata)
-return self._forward_prefill_flash(q, kv_c_normed, k_pe,

if attn_metadata.prefill_metadata.context_lens_tensor is not None and \
max(attn_metadata.prefill_metadata.context_lens_tensor) > 0:
entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1]
kv_c_k_pe_cache = torch.empty(
(entries_total, kv_c_and_k_pe_cache.shape[-1]),
dtype=kv_c_and_k_pe_cache.dtype,
device=kv_c_and_k_pe_cache.device,
)

assert kv_c_and_k_pe_cache.shape[-1] == 576
assert kv_c_and_k_pe_cache.shape[-2] == 16
_gather_kv_cache[(attn_metadata.num_prefills, )](
attn_metadata.prefill_metadata.seq_start_loc,
attn_metadata.prefill_metadata.block_tables,
attn_metadata.prefill_metadata.block_tables.stride(0),
kv_c_and_k_pe_cache,
kv_c_and_k_pe_cache.stride(0),
kv_c_k_pe_cache,
CACHE_PAGE_SIZE=576 * 16,
CACHE_ENTRY_SIZE=576,
CACHE_ENTRIES_PER_PAGE=16,
CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576),
CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16),
)

kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1)
k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1)

return self._forward_prefill_flash(q, kv_c, k_pe,
attn_metadata.seq_start_loc,
attn_metadata.max_prefill_seq_len)
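
The hard-coded 576 and 16 in the kernel launch correspond to the cache entry width and the cache block size asserted above, and the slicing into kv_c and k_pe follows from the latent width. A small sketch of that arithmetic, where kv_lora_rank = 512 and the 64-wide rotary portion are assumptions (DeepSeek-style MLA) rather than values stated in this diff:

# Assumed MLA dimensions; only 576 and 16 are asserted in the code above.
kv_lora_rank = 512                              # latent (compressed) KV width, assumed
qk_rope_head_dim = 64                           # rotary-embedding width, assumed
entry_size = kv_lora_rank + qk_rope_head_dim    # 576 == CACHE_ENTRY_SIZE
block_size = 16                                 # cache entries per page
page_numel = entry_size * block_size            # 9216 == CACHE_PAGE_SIZE

# Each gathered row holds one token's cache entry: the first kv_lora_rank
# columns are the latent KV (kv_c) and the remaining columns are the rotary
# part (k_pe), which is how the slicing above splits kv_c_k_pe_cache.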
