
feat: Wx/modify ops for npu #134

Open · wants to merge 9 commits into main
130 changes: 73 additions & 57 deletions deeplink_ext/internevo_ops/_flash_attention_npu.py
@@ -25,24 +25,25 @@ def flash_attn_func(
deterministic=False,
return_attn_probs=False,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

seqlen_q = q.shape[1]
seqlen_k = k.shape[1]
head_num = q.shape[-2]

if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048:
sparse_mode = 0
else:
sparse_mode = 2

seqlen_q = min(seqlen_q, 2048)
seqlen_k = min(seqlen_k, 2048)
assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k."
sparse_mode = 2 if seqlen_q > 2048 else 0
seqlen = min(seqlen_q, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
@@ -81,25 +82,28 @@ def flash_attn_varlen_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
block_table=None,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_num = q.shape[-2]

cu_seqlens_q = cu_seqlens_q[1:].tolist()
cu_seqlens_k = cu_seqlens_k[1:].tolist()
seqlen_q = min(max_seqlen_q, 2048)
seqlen_k = min(max_seqlen_k, 2048)

if max_seqlen_q < 2048:
sparse_mode = 0
else:
sparse_mode = 2
assert (
max_seqlen_q == max_seqlen_k
), "Npu currently only supports max_seqlen_q = max_seqlen_k."
sparse_mode = 2 if max_seqlen_q > 2048 else 0
max_seqlen = min(max_seqlen_q, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
@@ -114,8 +118,8 @@ def flash_attn_varlen_func(
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=q.shape[0],
next_tockens=0,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
@@ -134,6 +138,11 @@ def flash_attn_qkvpacked_func(
deterministic=False,
return_attn_probs=False,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
q = qkv[:, :, 0]
@@ -143,16 +152,12 @@ def flash_attn_qkvpacked_func(
seqlen_qkv = qkv.shape[1]
head_num = q.shape[-2]

if seqlen_qkv < 2048:
sparse_mode = 0
else:
sparse_mode = 2

seqlen_qkv = min(qkv.shape[1], 2048)
sparse_mode = 2 if seqlen_qkv > 2048 else 0
seqlen = min(seqlen_qkv, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device),
torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
@@ -187,26 +192,27 @@ def flash_attn_kvpacked_func(
deterministic=False,
return_attn_probs=False,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
k = kv[:, :, 0]
v = kv[:, :, 1]

s0 = q.shape[1]
s1 = kv.shape[1]
seqlen_q = q.shape[1]
seqlen_kv = kv.shape[1]
head_num = q.shape[-2]

if s0 == s1 and s0 < 2048 and s1 < 2048:
sparse_mode = 0
else:
sparse_mode = 2

seqlen_q = min(s0, 2048)
seqlen_k = min(s1, 2048)
assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv."
sparse_mode = 2 if seqlen_q > 2048 else 0
seqlen = min(seqlen_q, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
@@ -222,7 +228,7 @@ def flash_attn_kvpacked_func(
atten_mask=attention_mask,
scale=softmax_scale,
keep_prob=1 - dropout_p,
pre_tockens=seqlen_k,
pre_tockens=seqlen_q,
next_tockens=0,
sparse_mode=sparse_mode,
)[0]
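
The varlen wrappers in this file (including the two hunks below) drop the leading zero from the flash-attn style cu_seqlens before handing them to npu_fusion_attention, since actual_seq_qlen / actual_seq_kvlen take — judging from these wrappers — the cumulative end offset of each packed sequence rather than the boundary tensor itself. A minimal illustration of that conversion (plain PyTorch, no NPU required):

import torch

# Three packed sequences of lengths 16, 32, 32 in flash-attn convention.
cu_seqlens = torch.tensor([0, 16, 48, 80], dtype=torch.int32)

# What the wrappers pass as actual_seq_qlen / actual_seq_kvlen.
actual_seq_len = cu_seqlens[1:].tolist()  # -> [16, 48, 80]
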
@@ -242,37 +248,42 @@ def flash_attn_varlen_qkvpacked_func(
deterministic=False,
return_attn_probs=False,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
q = qkv[:, 0]
k = qkv[:, 1]
v = qkv[:, 2]
n = q.shape[1]
if max_seqlen > 2048:
sparse_mode = 2
else:
sparse_mode = 0
head_num = q.shape[1]

cu_seqlens_q = cu_seqlens[1:].tolist()
cu_seqlens_k = cu_seqlens[1:].tolist()
seqlen = min(max_seqlen, 2048)

sparse_mode = 2 if max_seqlen > 2048 else 0
max_seqlen = min(max_seqlen, 2048)
attention_mask = (
torch.triu(
torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)

out = torch_npu.npu_fusion_attention(
q,
k,
v,
n,
head_num,
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=q.shape[0],
next_tockens=0,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
@@ -296,39 +307,44 @@ def flash_attn_varlen_kvpacked_func(
deterministic=False,
return_attn_probs=False,
):
assert window_size == (
-1,
-1,
), "Npu currently does not support sliding window attention"
assert alibi_slopes is None, "Npu currently does not support ALiBi."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
k = kv[:, 0]
v = kv[:, 1]
n = q.shape[1]
head_num = q.shape[1]
cu_seqlens_q = cu_seqlens_q[1:].tolist()
cu_seqlens_k = cu_seqlens_k[1:].tolist()
seqlen_q = min(max_seqlen_q, 2048)
seqlen_k = min(max_seqlen_k, 2048)

if max_seqlen_q > 2048:
sparse_mode = 2
else:
sparse_mode = 0
assert (
max_seqlen_q == max_seqlen_k
), "Npu currently only supports max_seqlen_q = max_seqlen_k."
sparse_mode = 2 if max_seqlen_q > 2048 else 0
max_seqlen = min(max_seqlen_q, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)

out = torch_npu.npu_fusion_attention(
q,
k,
v,
n,
head_num,
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=q.shape[0],
next_tockens=0,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
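
All of the wrappers above share the same mask bookkeeping: the side length of the boolean mask is capped at 2048, sparse_mode flips from 0 to 2 once the true sequence length exceeds that cap, and the causal mask is the upper triangle above the diagonal. A CPU-only sketch of that shared logic, for reference; the helper name build_causal_mask is illustrative and not part of this PR:

import torch

def build_causal_mask(seqlen_q, seqlen_k, causal, device="cpu"):
    # The NPU wrappers assume equal query/key lengths (see the asserts above).
    assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k."
    # Mirror the wrappers: mode 2 once the capped 2048x2048 mask no longer
    # covers the full sequence, mode 0 otherwise.
    sparse_mode = 2 if seqlen_q > 2048 else 0
    seqlen = min(seqlen_q, 2048)
    attention_mask = (
        torch.triu(
            torch.ones(seqlen, seqlen, dtype=torch.bool, device=device),
            diagonal=1,
        )
        if causal
        else None
    )
    return attention_mask, sparse_mode

# A 4096-token causal case never materializes more than a 2048x2048 mask.
mask, mode = build_causal_mask(4096, 4096, causal=True)
assert mode == 2 and mask.shape == (2048, 2048)
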
91 changes: 63 additions & 28 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -1,8 +1,8 @@
# Copyright (c) 2024, DeepLink.

import torch
import torch_npu
from einops import rearrange
from einops import repeat
from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding

__all__ = ["ApplyRotaryEmb"]
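
The import swap above (rearrange → repeat, torch_npu → mindspeed's npu_rotary_position_embedding) supports the two cos/sin layouts built in forward below. A tiny CPU example of what the two repeat patterns from this diff produce:

import torch
from einops import repeat

cos = torch.tensor([[1.0, 2.0, 3.0]])  # (seqlen=1, rotary_dim // 2 = 3)

interleaved = repeat(cos, "... d -> 1 ... 1 (d 2)")      # pairwise duplication
non_interleaved = repeat(cos, "... d -> 1 ... 1 (2 d)")  # half-and-half duplication

print(interleaved.flatten().tolist())      # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
print(non_interleaved.flatten().tolist())  # [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
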

@@ -38,38 +38,73 @@ def forward(
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)

re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")

cat_cos = torch.cat([re_cos, re_cos], -1)
cat_sin = torch.cat([re_sin, re_sin], -1)
if interleaved:
cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (d 2)")
sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (d 2)")
else:
cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)")
sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)")

rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
ctx.save_for_backward(cat_cos, cat_sin)
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.in_place = in_place
if in_place:
x[..., :rotary_dim].copy_(rot)
return x

if interleaved:
x_ro = x[..., :rotary_dim]
out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1)
if in_place:
x[..., :rotary_dim].copy_(out_ro)
return x
if rotary_dim < head_dim:
out = torch.empty_like(x)
out[..., :rotary_dim].copy_(out_ro)
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
return out
return out_ro
else:
out = x.detach().clone()
if rotary_dim < head_dim and not in_place:
x_ro = x[..., :rotary_dim]
out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0)
if in_place:
x[..., :rotary_dim].copy_(out_ro)
return x
if rotary_dim < head_dim:
out = torch.empty_like(x)
out[..., :rotary_dim].copy_(out_ro)
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
return out
return out
return out_ro

@staticmethod
def backward(ctx, do):
cat_cos, cat_sin = ctx.saved_tensors
*_, seqlen, _, head_dim = do.shape
rotary_dim = cat_cos.shape[-1]
def backward(ctx, grad_out):
cos, sin = ctx.saved_tensors
rotary_dim = cos.shape[-1]
head_dim = grad_out.shape[-1]

dx_out = torch_npu.npu_rotary_mul(
do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
)
if ctx.in_place:
do[..., :rotary_dim].copy_(dx_out)
return do, None, None, None, None
if ctx.interleaved:
grad_out_ro = grad_out[..., :rotary_dim]
grad_input_ro = npu_rotary_position_embedding(
grad_out_ro, cos, torch.neg(sin), 1
)
if ctx.in_place:
grad_out[..., :rotary_dim].copy_(grad_input_ro)
return grad_out, None, None, None, None
if rotary_dim < head_dim:
grad_input = torch.empty_like(grad_out)
grad_input[..., :rotary_dim].copy_(grad_input_ro)
grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
return grad_input, None, None, None, None
return grad_input_ro, None, None, None, None
else:
dx = do.detach().clone()
dx[..., :rotary_dim].copy_(dx_out)
return dx, None, None, None, None
grad_out_ro = grad_out[..., :rotary_dim]
grad_input_ro = npu_rotary_position_embedding(
grad_out_ro, cos, torch.neg(sin), 0
)
if ctx.in_place:
grad_out[..., :rotary_dim].copy_(grad_input_ro)
return grad_out, None, None, None, None
if rotary_dim < head_dim:
grad_input = torch.empty_like(grad_out)
grad_input[..., :rotary_dim].copy_(grad_input_ro)
grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
return grad_input, None, None, None, None
return grad_input_ro, None, None, None, None
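
For readers unfamiliar with the convention: in the non-interleaved path (mode 0) the rewritten forward above delegates to npu_rotary_position_embedding, which is expected to apply the usual rotate-half rotary formula. A plain-PyTorch reference of that formula, meant as a readability aid rather than the kernel's exact implementation (helper names rotate_half / apply_rotary_ref are illustrative):

import torch
from einops import repeat

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (rotary_seqlen, rotary_dim // 2)
    rotary_dim = cos.shape[-1] * 2
    seqlen = x.shape[1]
    # Same broadcast layout the forward above builds for the non-interleaved path.
    cos = repeat(cos[:seqlen], "s d -> 1 s 1 (2 d)")
    sin = repeat(sin[:seqlen], "s d -> 1 s 1 (2 d)")
    out = x.clone()
    x_ro = x[..., :rotary_dim]
    out[..., :rotary_dim] = x_ro * cos + rotate_half(x_ro) * sin
    return out
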
3 changes: 1 addition & 2 deletions deeplink_ext/internevo_ops/rotary_embedding.py
@@ -4,8 +4,7 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
# from ._rotary_embedding_npu import ApplyRotaryEmb
from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
from ._rotary_embedding_npu import ApplyRotaryEmb
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import ApplyRotaryEmb
else:
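
With this change, importing ApplyRotaryEmb from deeplink_ext.internevo_ops.rotary_embedding resolves to the NPU implementation on TORCH_NPU platforms instead of the torch fallback. A usage sketch, with the argument order assumed from the forward/backward above, shapes taken from its asserts, and a torch_npu / Ascend environment required to actually run it:

import torch
from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb

# (batch, seqlen, nheads, headdim) on the NPU device.
x = torch.randn(2, 128, 8, 64, device="npu", dtype=torch.float16)
# cos/sin: (rotary_seqlen, rotary_dim // 2); here rotary_dim == headdim.
cos = torch.randn(4096, 32, device="npu", dtype=torch.float16)
sin = torch.randn(4096, 32, device="npu", dtype=torch.float16)

out = ApplyRotaryEmb.apply(x, cos, sin, False, False)  # interleaved=False, in_place=False
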