Skip to content

[feat] Support swa for npu #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 114 additions & 107 deletions deeplink_ext/internevo_ops/_flash_attention_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,43 @@
"flash_attn_varlen_kvpacked_func",
]

# construct a global attention mask for npu
_GLOBAL_ATTN_MASK = None


def get_attention_mask(seqlen, causal, window_size):
global _GLOBAL_ATTN_MASK

if _GLOBAL_ATTN_MASK is not None:
return _GLOBAL_ATTN_MASK

# causal attention
if causal:
if seqlen > 2048:
_GLOBAL_ATTN_MASK = torch.triu(
torch.ones([2048, 2048], dtype=bool, device=torch.npu.current_device()),
diagonal=1,
)
else:
_GLOBAL_ATTN_MASK = torch.triu(
torch.ones(
[seqlen, seqlen], dtype=bool, device=torch.npu.current_device()
),
diagonal=1,
)

# sliding window attention
if window_size[0] >= 0 or window_size[1] >= 0:
_GLOBAL_ATTN_MASK = torch.tril(
torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()),
diagonal=-((seqlen - 1 if window_size[0] < 0 else window_size[0]) + 1),
) + torch.triu(
torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()),
diagonal=(seqlen - 1 if window_size[1] < 0 else window_size[1]) + 1,
)

return _GLOBAL_ATTN_MASK


def flash_attn_func(
q,
Expand All @@ -32,22 +69,15 @@ def flash_attn_func(
seqlen_k = k.shape[1]
head_num = q.shape[-2]

if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048:
sparse_mode = 0
else:
sparse_mode = 2
assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k."
attention_mask = get_attention_mask(seqlen_q, causal, window_size)
sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4

seqlen_q = min(seqlen_q, 2048)
seqlen_k = min(seqlen_k, 2048)

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)
pre_tokens = seqlen_q - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0]
next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
Expand All @@ -58,8 +88,8 @@ def flash_attn_func(
atten_mask=attention_mask,
scale=softmax_scale,
keep_prob=1 - dropout_p,
pre_tockens=seqlen_q,
next_tockens=0,
pre_tockens=pre_tokens,
next_tockens=next_tokens,
sparse_mode=sparse_mode,
)[0]

Expand Down Expand Up @@ -89,22 +119,18 @@ def flash_attn_varlen_func(

cu_seqlens_q = cu_seqlens_q[1:].tolist()
cu_seqlens_k = cu_seqlens_k[1:].tolist()
seqlen_q = min(max_seqlen_q, 2048)
seqlen_k = min(max_seqlen_k, 2048)

if max_seqlen_q < 2048:
sparse_mode = 0
else:
sparse_mode = 2

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)

assert (
max_seqlen_q == max_seqlen_k
), "Npu currently only supports max_seqlen_q = max_seqlen_k."
attention_mask = get_attention_mask(max_seqlen_q, causal, window_size)
sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4

pre_tokens = max_seqlen_q - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0]
next_tokens = max_seqlen_q - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
Expand All @@ -114,8 +140,8 @@ def flash_attn_varlen_func(
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=pre_tokens,
next_tockens=next_tokens,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
Expand Down Expand Up @@ -143,21 +169,14 @@ def flash_attn_qkvpacked_func(
seqlen_qkv = qkv.shape[1]
head_num = q.shape[-2]

if seqlen_qkv < 2048:
sparse_mode = 0
else:
sparse_mode = 2

seqlen_qkv = min(qkv.shape[1], 2048)
attention_mask = get_attention_mask(seqlen_qkv, causal, window_size)
sparse_mode = 0 if attention_mask is None or seqlen_qkv <= 2048 else 4

attention_mask = (
torch.triu(
torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)
pre_tokens = seqlen_qkv - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = seqlen_qkv - 1 if window_size[0] < 0 else window_size[0]
next_tokens = seqlen_qkv - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
Expand All @@ -168,8 +187,8 @@ def flash_attn_qkvpacked_func(
atten_mask=attention_mask,
scale=softmax_scale,
keep_prob=1 - dropout_p,
pre_tockens=seqlen_qkv,
next_tockens=0,
pre_tockens=pre_tokens,
next_tockens=next_tokens,
sparse_mode=sparse_mode,
)[0]

Expand All @@ -192,26 +211,19 @@ def flash_attn_kvpacked_func(
k = kv[:, :, 0]
v = kv[:, :, 1]

s0 = q.shape[1]
s1 = kv.shape[1]
seqlen_q = q.shape[1]
seqlen_kv = kv.shape[1]
head_num = q.shape[-2]

if s0 == s1 and s0 < 2048 and s1 < 2048:
sparse_mode = 0
else:
sparse_mode = 2

seqlen_q = min(s0, 2048)
seqlen_k = min(s1, 2048)
assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv."
attention_mask = get_attention_mask(seqlen_q, causal, window_size)
sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)
pre_tokens = seqlen_q - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0]
next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
Expand All @@ -222,8 +234,8 @@ def flash_attn_kvpacked_func(
atten_mask=attention_mask,
scale=softmax_scale,
keep_prob=1 - dropout_p,
pre_tockens=seqlen_k,
next_tockens=0,
pre_tockens=pre_tokens,
next_tockens=next_tokens,
sparse_mode=sparse_mode,
)[0]

Expand All @@ -247,32 +259,30 @@ def flash_attn_varlen_qkvpacked_func(
q = qkv[:, 0]
k = qkv[:, 1]
v = qkv[:, 2]
n = q.shape[1]
if max_seqlen > 2048:
sparse_mode = 2
else:
sparse_mode = 0
head_num = q.shape[1]

cu_seqlens_q = cu_seqlens[1:].tolist()
cu_seqlens_k = cu_seqlens[1:].tolist()
seqlen = min(max_seqlen, 2048)
attention_mask = (
torch.triu(
torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)

attention_mask = get_attention_mask(max_seqlen, causal, window_size)
sparse_mode = 0 if attention_mask is None or max_seqlen <= 2048 else 4

pre_tokens = max_seqlen - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = max_seqlen - 1 if window_size[0] < 0 else window_size[0]
next_tokens = max_seqlen - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
k,
v,
n,
head_num,
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=pre_tokens,
next_tockens=next_tokens,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
Expand Down Expand Up @@ -300,35 +310,32 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale = q.shape[-1] ** (-0.5)
k = kv[:, 0]
v = kv[:, 1]
n = q.shape[1]
head_num = q.shape[1]
cu_seqlens_q = cu_seqlens_q[1:].tolist()
cu_seqlens_k = cu_seqlens_k[1:].tolist()
seqlen_q = min(max_seqlen_q, 2048)
seqlen_k = min(max_seqlen_k, 2048)

if max_seqlen_q > 2048:
sparse_mode = 2
else:
sparse_mode = 0

attention_mask = (
torch.triu(
torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
diagonal=1,
)
if causal
else None
)

assert (
max_seqlen_q == max_seqlen_k
), "Npu currently only supports max_seqlen_q = max_seqlen_k."
attention_mask = get_attention_mask(max_seqlen_q, causal, window_size)
sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4

pre_tokens = max_seqlen_q - 1
next_tokens = 0
if window_size[0] >= 0 or window_size[1] >= 0:
pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0]
next_tokens = max_seqlen_k - 1 if window_size[1] < 0 else window_size[1]

out = torch_npu.npu_fusion_attention(
q,
k,
v,
n,
head_num,
"TND",
atten_mask=attention_mask,
scale=softmax_scale,
pre_tockens=q.shape[0], # seq_len
next_tockens=0, # 0
pre_tockens=pre_tokens,
next_tockens=next_tokens,
keep_prob=1 - dropout_p,
sparse_mode=sparse_mode,
actual_seq_qlen=cu_seqlens_q,
Expand Down
Loading
Loading