[DICP] Add paged attention. #864

Open · wants to merge 2 commits into base: main
9 changes: 7 additions & 2 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -1609,15 +1609,20 @@ def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim, num_ke
        return op.to_node()

    @staticmethod
-    def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, input_layout="BSH"):
+    def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, input_layout="BSH", block_table=None, seq_lengths=None, block_size=128):
        op = OP(name, "IncreFlashAttention")
        op.set_input("query", q)
        op.set_dynamic_input("key", kv_input_num, k_list)
        op.set_dynamic_input("value", kv_input_num, v_list)
        op.set_attr_int("num_heads", head_num)
-        op.set_attr_float("scale_value", float(1 / math.sqrt(dim)))
+        if not block_table:
+            op.set_attr_float("scale_value", float(1 / math.sqrt(dim)))
        op.set_attr_int("num_key_value_heads", kv_head_num)
        op.set_attr_str("input_layout", input_layout)
+        if block_table:
+            op.set_input("block_table", block_table)
+            op.set_input("actual_seq_lengths", seq_lengths)
+            op.set_attr_int("block_size", block_size)
        return op.to_node()

    @staticmethod
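Note for readers new to the paged layout these new inputs describe (background, not part of the diff): each row of block_table lists the physical cache blocks backing one request's KV history, block_size is the number of tokens per block, and actual_seq_lengths bounds how many of the mapped slots are valid. A small plain-PyTorch sketch of that bookkeeping, written as an assumption about the operator's paged mode rather than taken from this PR:

```python
# Minimal sketch (assumption, not from this PR): gathering one request's keys
# from a block-paged cache, the way a paged-attention kernel conceptually does.
import torch

block_size = 4
num_blocks, kv_heads, dim = 8, 2, 16
# Physical KV cache: [num_blocks, block_size, kv_heads, dim]
k_cache = torch.randn(num_blocks, block_size, kv_heads, dim)

block_table = torch.tensor([[2, 5]])   # request 0 is stored in physical blocks 2 and 5
actual_seq_length = 6                  # only 6 of the 2 * block_size mapped slots hold real tokens

# Gather the request's blocks, flatten to token order, then trim to the valid length.
k_request = k_cache[block_table[0]]                   # [2, block_size, kv_heads, dim]
k_request = k_request.reshape(-1, kv_heads, dim)[:actual_seq_length]
print(k_request.shape)                                # torch.Size([6, 2, 16])
```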
29 changes: 26 additions & 3 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -1721,6 +1721,8 @@ def lightllm_rotary_emb(self, x, cos, sin):

        seq_len = x_shape[0]
        dim = x_shape[2]
        if isinstance(dim, torch.fx.proxy.Proxy):
            dim = int(sympy.N(dim.node.meta['val']))

        cos_sin_shape = self.get_shape_proxy([seq_len, 1, dim // 2])
        cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape))
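The isinstance(dim, torch.fx.proxy.Proxy) guard added here (and reused by paged_attention_inference below) concretizes a dimension that arrives as an FX proxy: node.meta['val'] holds the value recorded during tracing, and sympy.N(...) evaluates it to a number usable as a plain Python int. A minimal sketch of just that numeric step, with a made-up val standing in for dim.node.meta['val']:

```python
# Sketch of the concretization step only; `val` is a stand-in for the
# symbolic value stored in an FX node's meta['val'].
import sympy

val = sympy.Integer(128)      # e.g. a statically known head dimension
dim = int(sympy.N(val))       # sympy.N evaluates the expression numerically
assert dim == 128 and isinstance(dim, int)
```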
@@ -1764,7 +1766,7 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim, num_ke
            return self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(q_dtype)))
        return fa

-    def incre_flash_attention(self, q, k, v, kv_head_num, head_num, dim):
+    def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim):
        k_list = []
        v_list = []
        if not isinstance(k, list):
@@ -1777,7 +1779,7 @@ def incre_flash_attention(self, q, k, v, kv_head_num, head_num, dim):
            v_list = v
        assert len(k_list) == len(v_list)
        kv_input_num = len(k_list)
-        out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, "BSH"))
+        out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, "BSH"))
        return out

    @register_conversion(aten.select_scatter.default)
@@ -1813,6 +1815,11 @@ def copy_with_offset(self, x, src, start_dim, end_dim):
            src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(src_dtype)))
        return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src))

    @register_conversion(torch.ops.lightllm.copy_with_index.default)
    def copy_with_index(self, x, src, dims):
        dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1]))
        return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src))
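The lowering maps copy_with_index onto ScatterNdUpdate, whose index tensor needs a trailing coordinate axis, hence the Unsqueeze(dims, [-1]) that turns a flat index vector of shape [n] into [n, 1]. A rough eager-mode illustration of the intended semantics (not the Ascend kernel itself):

```python
# Eager-mode illustration only: what the lowered ScatterNdUpdate graph is meant to compute.
import torch

x = torch.zeros(6, 3)
src = torch.arange(6.0).reshape(2, 3)
index = torch.tensor([1, 4])

# ScatterNdUpdate-style index layout: one coordinate tuple per updated row.
nd_index = index.unsqueeze(-1)
assert nd_index.shape == (2, 1)

# Semantics the op implements:
x[index] = src
assert torch.equal(x[1], src[0]) and torch.equal(x[4], src[1])
```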

    @register_conversion(torch.ops.lightllm.flash_attention_inference.default)
    def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1):
        q_shape = list(q.node.meta['val'].shape)
@@ -1867,7 +1874,7 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhe
            xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape))
            xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape))

-            out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH
+            out = self.incre_flash_attention(xq, k, v, head, kvhead, dim) # q shape is BSH
            out_shape = self.get_shape_proxy([compute_batch, 1, head, dim])
            out_shape2 = self.get_shape_proxy([compute_batch, head, dim])
            out = self.get_proxy(ascend_op.Reshape, (out, out_shape))
@@ -1876,3 +1883,19 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhe

        res = self.get_proxy(ascend_op.ConcatD, (res, 0))
        return res

    @register_conversion(torch.ops.lightllm.paged_attention_inference.default)
    def paged_attention_inference(self, q, all_k, all_v, q_head_num, dim, kv_head_num, block_table=None, seq_lengths=None, block_size=128):
        if isinstance(q_head_num, torch.fx.proxy.Proxy):
            q_head_num = int(sympy.N(q_head_num.node.meta['val']))
        if isinstance(dim, torch.fx.proxy.Proxy):
            dim = int(sympy.N(dim.node.meta['val']))
        if isinstance(kv_head_num, torch.fx.proxy.Proxy):
            kv_head_num = int(sympy.N(kv_head_num.node.meta['val']))
        q = self.get_proxy(ascend_op.Unsqueeze, (q, [1]))
        all_k = self.get_proxy(ascend_op.Unsqueeze, (all_k, [1]))
        all_v = self.get_proxy(ascend_op.Unsqueeze, (all_v, [1]))
        out = self.get_proxy(ascend_op.IncreFlashAttention, (q, [all_k], [all_v], 1,
                             q_head_num, kv_head_num, dim, "BSH", block_table, seq_lengths, block_size))
        out = self.get_proxy(ascend_op.Squeeze, (out, [1]))
        return out
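The new conversion wraps the decode-step tensors in a singleton sequence axis before calling IncreFlashAttention and strips it from the result. A shape-only sketch of that wrapping, using the tensor ranks from the unit test below; whether the kernel reads the trailing dimensions as H or as N, D under the "BSH" layout is an Ascend-side detail not asserted here:

```python
# Shape-only sketch of the Unsqueeze/Squeeze wrapping around the fused op.
import torch

batch, q_head_num, dim = 1, 8, 16
q = torch.randn(batch, q_head_num, dim)   # decode-step query, no sequence axis

q_bsh = q.unsqueeze(1)                    # [batch, 1, q_head_num, dim]: sequence length 1
# ... the fused attention op would run here on q_bsh / k / v plus block_table ...
out = q_bsh                               # stand-in for the op's output
out = out.squeeze(1)                      # back to [batch, q_head_num, dim]
assert out.shape == q.shape
```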
58 changes: 58 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ext_ops.py
@@ -132,6 +132,49 @@ def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_l
    res = torch.cat(res)
    return res

@torch._custom_op.impl.custom_op('lightllm::paged_attention_inference')
def paged_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, q_head_num: int, dim: int, kv_head_num: int, block_table: Tensor, seq_lengths: Tensor, block_size: int) -> Tensor:
    ...


@paged_attention_inference.impl_abstract()
def lightllm_paged_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, q_head_num: int, dim: int, kv_head_num: int, block_table: Tensor, seq_lengths: Tensor, block_size: int):
    return torch.empty_like(q)


@paged_attention_inference.impl(['cpu', 'cuda'])
def lightllm_paged_attention_inference_impl(q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size):
    # q: batch, head, dim
    batch = q.shape[0]
    head = q_head_num
    current_lens = seq_lengths

    res = []
    compute_batch = 1
    for i in range(batch):
        current_len = current_lens[i]
        kv_seq_len = current_len

        k = all_k[:current_len].reshape(compute_batch, kv_seq_len, head, dim)
        v = all_v[:current_len].reshape(compute_batch, kv_seq_len, head, dim)

        xq = q[i].view(compute_batch, 1, head, dim).transpose(1, 2).transpose(0, 1)  # shape: head, batch, 1, dim
        bmm_xq = xq.reshape(head * compute_batch, 1, dim).float()
        bmm_xk = k.transpose(1, 2).transpose(0, 1).transpose(2, 3).reshape(head * compute_batch, dim, kv_seq_len).float()

        # q @ k
        out = torch.bmm(bmm_xq, bmm_xk) / math.sqrt(dim)
        out = out.reshape(head, compute_batch, 1, -1).reshape(head, compute_batch, -1)

        # softmax
        out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1)  # shape: batch, head, 1, seq_len
        xv = v.transpose(1, 2).float()  # shape: batch, head, seq_len, dim
        out = torch.bmm(out.reshape(compute_batch * head, 1, kv_seq_len), xv.reshape(compute_batch * head, kv_seq_len, dim))

        out = out.reshape(compute_batch, head, 1, dim).view(compute_batch, head, dim)
        res.append(out)
    res = torch.cat(res)
    return res
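Worth noting: this eager CPU/CUDA fallback reads the cache contiguously via all_k[:current_len] and never consumes block_table or block_size; those only matter to the Ascend kernel. Per request, the loop body is ordinary single-token scaled dot-product attention, which the hedged sketch below cross-checks against torch.nn.functional.scaled_dot_product_attention (one request, KV heads already expanded to head):

```python
# Cross-check sketch: the per-request loop body above is standard
# single-token scaled dot-product attention.
import math
import torch
import torch.nn.functional as F

head, dim, kv_seq_len = 8, 16, 10
q_i = torch.randn(head, dim)               # one request's query, as in q[i]
k = torch.randn(1, kv_seq_len, head, dim)  # as in all_k[:current_len].reshape(1, kv_seq_len, head, dim)
v = torch.randn(1, kv_seq_len, head, dim)

# Loop-body computation, condensed: per-head softmax(q k^T / sqrt(dim)) @ v
attn = torch.einsum('hd,shd->hs', q_i, k[0]) / math.sqrt(dim)
ref = torch.einsum('hs,shd->hd', attn.softmax(-1), v[0])

# The same thing via the built-in fused op.
sdpa = F.scaled_dot_product_attention(
    q_i.view(1, head, 1, dim),             # [batch, heads, 1, dim]
    k[0].permute(1, 0, 2).unsqueeze(0),    # [batch, heads, kv_seq_len, dim]
    v[0].permute(1, 0, 2).unsqueeze(0),
)
assert torch.allclose(ref, sdpa.reshape(head, dim), atol=1e-5)
```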


@torch._custom_op.impl.custom_op('lightllm::copy_with_offset')
def copy_with_offset(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor:
@@ -147,3 +190,18 @@ def lightllm_copy_with_offset_abstract(x: Tensor, src: Tensor, start_dim: int, e
def lightllm_copy_with_offset_impl(x, src, start_dim, end_dim) -> Tensor:
    x[start_dim:end_dim] = src
    return x

@torch._custom_op.impl.custom_op('lightllm::copy_with_index')
def copy_with_index(x: Tensor, src: Tensor, index: Tensor) -> Tensor:
    ...


@copy_with_index.impl_abstract()
def lightllm_copy_with_index_abstract(x: Tensor, src: Tensor, index: Tensor) -> Tensor:
    return x


@copy_with_index.impl(['cpu', 'cuda'])
def lightllm_copy_with_index_impl(x, src, index) -> Tensor:
    x[index] = src
    return x
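As a usage sketch (a hypothetical call site, not taken from this PR): copy_with_index is the row-indexed counterpart of copy_with_offset; instead of writing a contiguous [start_dim:end_dim] slice it scatters src rows to arbitrary destination slots, which is what a decode-step KV cache needs when each request appends at a different position:

```python
# Hypothetical call site: scatter this step's new K entries into per-request
# cache rows. Assumes the ext_ops module has been imported so the custom op
# is registered (as the unit test below does).
import torch

kv_cache = torch.zeros(16, 2, 8)        # [cache_slots, kv_heads, dim]
new_k = torch.randn(3, 2, 8)            # one new token for each of 3 requests
dest_slots = torch.tensor([5, 9, 12])   # where each request's token lands

updated = torch.ops.lightllm.copy_with_index.default(kv_cache, new_k, dest_slots)
assert torch.equal(updated[9], new_k[1])
```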
61 changes: 61 additions & 0 deletions dicp/test/op/test_lightllm_paged_attention.py
@@ -0,0 +1,61 @@
import pytest

from dicp.vendor.AscendGraph import ext_ops
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size):
        res = torch.ops.lightllm.paged_attention_inference.default(q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size)
        return res


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestLightllmPagedAttention():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size(((10,), (8, 16), (8, 16)), ((10,), (8, 16), (8, 16))), Size(((10,), (16, 32), (2, 32)), ((10,), (16, 32), (2, 32)))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_lightllm_paged_attention(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static

q = torch.randn((1,) + size[1], dtype=dtype)
k = torch.randn(size[0] + size[2], dtype=dtype)
v = torch.randn(size[0] + size[2], dtype=dtype)

q_head_num = size[1][0]
dim = size[1][1]
kv_head_num = size[2][0]
block_table = torch.tensor([[0]], dtype=torch.int32)
seq_lengths = list(size[0])
block_size = 128

dicp_q = q.to(device)
dicp_k = k.to(device)
dicp_v = v.to(device)
dicp_block_table = block_table.to(device)
dicp_seq_lengths = torch.tensor([seq_lengths], device=device, dtype=torch.int64)

        if q_head_num != kv_head_num:
            repeat = q_head_num // kv_head_num
            k = k.repeat(1, repeat, 1)
            v = v.repeat(1, repeat, 1)

        output = model(q, k, v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size).half().reshape(-1, q_head_num, dim)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_q, dicp_k, dicp_v, q_head_num, dim, kv_head_num, dicp_block_table, dicp_seq_lengths, block_size).reshape(-1, q_head_num, dim)

        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
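For the grouped-query configuration (Size(((10,), (16, 32), (2, 32)), ...)), the eager reference needs the KV heads tiled up to the query head count before it is called; a standalone sketch of that expansion, and the reason the integer division above matters for Tensor.repeat:

```python
# Standalone sketch of the KV-head expansion used by the eager reference path.
import torch

seq, q_head_num, kv_head_num, dim = 10, 16, 2, 32
k = torch.randn(seq, kv_head_num, dim)

group = q_head_num // kv_head_num    # Tensor.repeat requires an integer repeat count
k_expanded = k.repeat(1, group, 1)   # tile the kv heads along the head axis, as the test does
assert k_expanded.shape == (seq, q_head_num, dim)
```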