feat: fix bug of rope for npu #130

Open: wants to merge 2 commits into base: main
112 changes: 83 additions & 29 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,11 +2,44 @@

import torch
import torch_npu
from einops import rearrange
from einops import rearrange, repeat

__all__ = ["ApplyRotaryEmb"]


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )
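
A quick illustration of the two layouts handled by rotate_half above (a minimal sketch; the values are made up and assume the function is in scope):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Non-interleaved: halves (1, 2) and (3, 4) -> (-3, -4, 1, 2).
print(rotate_half(x))                    # tensor([-3., -4.,  1.,  2.])
# Interleaved: pairs (1, 2) and (3, 4) -> (-2, 1, -4, 3).
print(rotate_half(x, interleaved=True))  # tensor([-2.,  1., -4.,  3.])
```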


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
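
A minimal CPU sketch of how apply_rotary_emb_torch behaves, assuming the helpers above are in scope; the shapes follow the docstring and the sizes are illustrative only:

```python
import torch

# batch=1, seqlen=4, nheads=2, headdim=8; rotary_dim = 2 * cos.shape[-1] = 4.
x = torch.randn(1, 4, 2, 8)
theta = torch.randn(4, 2)                # (seqlen, rotary_dim / 2)
cos, sin = torch.cos(theta), torch.sin(theta)

out = apply_rotary_emb_torch(x, cos, sin, interleaved=False)
assert out.shape == x.shape
# Only the first rotary_dim channels are rotated; the tail passes through.
assert torch.equal(out[..., 4:], x[..., 4:])
```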


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
class ApplyRotaryEmb(torch.autograd.Function):
    """
@@ -38,38 +71,59 @@ def forward(
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)

        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")

        cat_cos = torch.cat([re_cos, re_cos], -1)
        cat_sin = torch.cat([re_sin, re_sin], -1)

        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
        ctx.save_for_backward(cat_cos, cat_sin)
        if interleaved:
            cos = cos[:seqlen]
            sin = sin[:seqlen]
        else:
            # "s d -> 1 s 1 d"
            cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
            sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.in_place = in_place
        if in_place:
            x[..., :rotary_dim].copy_(rot)
            return x
        if interleaved:
            out = apply_rotary_emb_torch(x, cos, sin, interleaved)
            if in_place:
                x.copy_(out)
                return x
            else:
                return out
        else:
            out = x.detach().clone()
            if rotary_dim < head_dim and not in_place:
            x_ro = x[..., :rotary_dim]
            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
            if in_place:
                x[..., :rotary_dim].copy_(out_ro)
                return x
            if rotary_dim < head_dim:
                out = torch.empty_like(x)
                out[..., :rotary_dim].copy_(out_ro)
                out[..., rotary_dim:].copy_(x[..., rotary_dim:])
                return out
            return out
            return out_ro
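
The non-interleaved branch above expands cos/sin from (seqlen, rotary_dim / 2) to (1, seqlen, 1, rotary_dim) by tiling the half dimension so it broadcasts against x over batch and heads. A small sketch of that reshaping, analogous to the repeat used in apply_rotary_emb_torch (sizes are illustrative):

```python
import torch
from einops import repeat

seqlen, half = 4, 2
cos = torch.randn(seqlen, half)

# Reshape used in forward(): (s, d) -> (1, s, 1, 2 * d), tiling the last dim.
a = cos.unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
# The same layout via einops: two concatenated copies of the last dim.
b = repeat(cos, "s d -> 1 s 1 (two d)", two=2)
assert torch.equal(a, b)
```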

    @staticmethod
    def backward(ctx, do):
        cat_cos, cat_sin = ctx.saved_tensors
        *_, seqlen, _, head_dim = do.shape
        rotary_dim = cat_cos.shape[-1]

        dx_out = torch_npu.npu_rotary_mul(
            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
        )
        if ctx.in_place:
            do[..., :rotary_dim].copy_(dx_out)
            return do, None, None, None, None
    def backward(ctx, grad_out):
        cos, sin = ctx.saved_tensors
        rotary_dim = cos.shape[-1]
        head_dim = grad_out.shape[-1]
        if ctx.interleaved:
            grad_input = apply_rotary_emb_torch(
                grad_out, cos, torch.neg(sin), ctx.interleaved
            )
            if ctx.in_place:
                grad_out.copy_(grad_input)
                return grad_out, None, None, None, None
            else:
                return grad_input, None, None, None, None
        else:
            dx = do.detach().clone()
            dx[..., :rotary_dim].copy_(dx_out)
            return dx, None, None, None, None
            grad_out_ro = grad_out[..., :rotary_dim]
            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
            if ctx.in_place:
                grad_out[..., :rotary_dim].copy_(grad_input_ro)
                return grad_out, None, None, None, None
            if rotary_dim < head_dim:
                grad_input = torch.empty_like(grad_out)
                grad_input[..., :rotary_dim].copy_(grad_input_ro)
                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
                return grad_input, None, None, None, None
            return grad_input_ro, None, None, None, None
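
The backward pass applies the same rotation with sin negated (the transpose of the forward map) to the incoming gradient. A CPU sketch of that identity using the pure-PyTorch helper from this file (float64 for tight tolerances; sizes are illustrative):

```python
import torch

x = torch.randn(1, 4, 2, 8, dtype=torch.float64, requires_grad=True)
theta = torch.randn(4, 2, dtype=torch.float64)
cos, sin = torch.cos(theta), torch.sin(theta)

out = apply_rotary_emb_torch(x, cos, sin, interleaved=False)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Rotating grad_out with (cos, -sin) reproduces autograd's gradient w.r.t. x.
manual_grad = apply_rotary_emb_torch(grad_out, cos, -sin, interleaved=False)
assert torch.allclose(x.grad, manual_grad)
```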
3 changes: 1 addition & 2 deletions deeplink_ext/internevo_ops/rotary_embedding.py
@@ -4,8 +4,7 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
    # from ._rotary_embedding_npu import ApplyRotaryEmb
    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
    from ._rotary_embedding_npu import ApplyRotaryEmb
elif platform_type == PlatformType.TORCH_DIPU:
    from ._rotary_embedding_dipu import ApplyRotaryEmb
else:
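With this change, the NPU platform resolves ApplyRotaryEmb to the torch_npu-backed implementation instead of the pure-PyTorch fallback, while callers keep importing the same symbol. A hedged sketch of a call site (the positional argument order x, cos, sin, interleaved, in_place is inferred from the test below; shapes and the "npu" device are illustrative assumptions):

```python
import torch
from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb

# x: (batch, seqlen, nheads, headdim); cos/sin: (rotary_seqlen, rotary_dim / 2).
x = torch.randn(1, 64, 32, 64, dtype=torch.float16, device="npu")
cos = torch.randn(64, 32, dtype=torch.float16, device="npu")
sin = torch.randn(64, 32, dtype=torch.float16, device="npu")

out = ApplyRotaryEmb.apply(x, cos, sin, False, False)
```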
67 changes: 34 additions & 33 deletions tests/internevo/test_rotary_embedding.py
@@ -8,40 +8,41 @@

def test_ApplyRotaryEmb():
    input_dtype_list = [torch.float16, torch.bfloat16]
    interleaved = False
    in_place_options = [False, True]
    interleaved_options = [False, True]
    for input_dtype in input_dtype_list:
        for in_place in in_place_options:
            input_ref = torch.randn(
                1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
            )
            input_ext = input_ref.clone().detach().requires_grad_()
            cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
            sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
            for interleaved in interleaved_options:
                input_ref = torch.randn(
                    1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
                )
                input_ext = input_ref.clone().detach().requires_grad_()
                cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
                sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")

            output_ref, grad_ref = call_autograd_func(
                ApplyRotaryEmbTorch,
                "cuda",
                input_dtype,
                input_ref,
                cos,
                sin,
                interleaved,
                in_place,
            )
            output_ext, grad_ext = call_autograd_func(
                ApplyRotaryEmb,
                "cuda",
                input_dtype,
                input_ext,
                cos,
                sin,
                interleaved,
                in_place,
            )
            assert allclose(
                output_ref, output_ext, rtol=1e-2, atol=5e-2
            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
            assert allclose(
                grad_ref, grad_ext
            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
                output_ref, grad_ref = call_autograd_func(
                    ApplyRotaryEmbTorch,
                    "cuda",
                    input_dtype,
                    input_ref,
                    cos,
                    sin,
                    interleaved,
                    in_place,
                )
                output_ext, grad_ext = call_autograd_func(
                    ApplyRotaryEmb,
                    "cuda",
                    input_dtype,
                    input_ext,
                    cos,
                    sin,
                    interleaved,
                    in_place,
                )
                assert allclose(
                    output_ref, output_ext, rtol=1e-2, atol=5e-2
                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
                assert allclose(
                    grad_ref, grad_ext
                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
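
For readers without the repo's helpers, a hand-rolled sketch of one comparison iteration (in_place=False only; call_autograd_func and allclose are repo utilities, so this reconstructs the intent of the check rather than the helper itself, and the argument order follows the calls above):

```python
import torch

dtype, interleaved = torch.float16, True
x_ref = torch.randn(1, 64, 32, 64, dtype=dtype, device="cuda", requires_grad=True)
x_ext = x_ref.clone().detach().requires_grad_()
cos = torch.randn(64, 32, dtype=dtype, device="cuda")
sin = torch.randn(64, 32, dtype=dtype, device="cuda")

out_ref = ApplyRotaryEmbTorch.apply(x_ref, cos, sin, interleaved, False)
out_ext = ApplyRotaryEmb.apply(x_ext, cos, sin, interleaved, False)
out_ref.backward(torch.ones_like(out_ref))
out_ext.backward(torch.ones_like(out_ext))

assert torch.allclose(out_ref, out_ext, rtol=1e-2, atol=5e-2)
assert torch.allclose(x_ref.grad, x_ext.grad, rtol=1e-2, atol=5e-2)
```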