
Commit a7a044b

refactor and remove pos_id in apply_rotary_emb
1 parent bd43917 commit a7a044b

7 files changed: 16 additions, 27 deletions

lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py

Lines changed: 2 additions & 6 deletions
@@ -5,7 +5,6 @@
 from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb
 
 from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl
-from .attention import DlinferAttentionMetadata
 
 class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
     """Apply rotary embedding implementation."""
@@ -15,19 +14,16 @@ def forward(self,
                 key: Tensor,
                 cos: Tensor,
                 sin: Tensor,
-                attn_metadata: DlinferAttentionMetadata,
+                cu_seqlens: Tensor,
                 inplace: bool = True):
         """forward."""
-        cos_sin_ids = attn_metadata.cos_sin_ids
-        cu_seqlens = attn_metadata.cu_seqlens
-
         if inplace:
             q_embed = None
             k_embed = None
         else:
             q_embed = torch.empty_like(query)
             k_embed = torch.empty_like(key)
-        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cos_sin_ids, cu_seqlens)
+        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cu_seqlens)
 
 
 class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
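
For reference, DlinferApplyRotaryEmbImpl.forward reads roughly as follows once the two hunks above are applied. This is a reconstruction: the query: Tensor parameter and the torch/Tensor imports are not visible in the diff and are assumed from the unchanged context.

import torch
from torch import Tensor  # assumed import; not shown in the hunk

from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
    """Apply rotary embedding implementation."""

    def forward(self,
                query: Tensor,  # assumed; the hunk starts at `key`
                key: Tensor,
                cos: Tensor,
                sin: Tensor,
                cu_seqlens: Tensor,
                inplace: bool = True):
        """forward."""
        if inplace:
            # in place: the kernel rotates query/key in their own storage
            q_embed = None
            k_embed = None
        else:
            # out of place: pre-allocate destination buffers for the kernel
            q_embed = torch.empty_like(query)
            k_embed = torch.empty_like(key)
        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed,
                                    cu_seqlens)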

lmdeploy/pytorch/backends/dlinfer/attention.py

Lines changed: 6 additions & 3 deletions
@@ -10,14 +10,13 @@
 @dataclass
 class DlinferAttentionMetadata(AttentionMetadata):
     kv_start_indices: Optional[Tensor] = None
-    block_size: int = 16
+    block_size: int = 64
     attention_mask: Sequence[Tensor] = tuple()
     is_unpaged_prefill: Optional[bool] = None
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
     cu_seqlens: Optional[Tensor] = None
-    cos_sin_ids: Optional[Tensor] = None
-
+    is_flash_attn_support_inplace: bool = True
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
     """dlinfer attention implementation."""
@@ -82,6 +81,10 @@ def forward(
         k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
                                               kv_start_indices)
 
+        if is_unpaged_prefill:
+            inplace = inplace if attn_metadata.is_flash_attn_support_inplace \
+                else False
+
         if inplace:
             attn_output = query[..., :self.v_head_size]
         else:

lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py

Lines changed: 1 addition & 7 deletions
@@ -62,12 +62,6 @@ def update_step_context(cls, step_context):
         cu_seqlens = torch.zeros(batch_size+1, dtype=torch.int32, device=device)
         cu_seqlens[:-1] = step_context.q_start_loc
         cu_seqlens[-1] = step_context.q_seqlens.sum()
-        cu_seqlens_list = cu_seqlens.tolist()
-
-        if not step_context.is_decoding:
-            cos_sin_ids = step_context.position_ids[0].to(torch.int32)
-        else:
-            cos_sin_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
 
         if not step_context.is_decoding:
             is_unpaged_prefill = \
@@ -104,7 +98,7 @@ def update_step_context(cls, step_context):
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
             cu_seqlens=cu_seqlens,
-            cos_sin_ids=cos_sin_ids,
+            is_flash_attn_support_inplace=False,
         )
 
         step_context.attn_metadata = attn_metadata
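
To make the cu_seqlens layout concrete, here is a small standalone example of the construction retained above; the q_start_loc and q_seqlens values are illustrative, not taken from the commit.

import torch

# illustrative prefill batch: three sequences of lengths 3, 4 and 5
q_start_loc = torch.tensor([0, 3, 7], dtype=torch.int32)
q_seqlens = torch.tensor([3, 4, 5], dtype=torch.int32)
batch_size = q_seqlens.numel()

# same construction as in update_step_context above
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[:-1] = q_start_loc
cu_seqlens[-1] = q_seqlens.sum()

print(cu_seqlens)  # tensor([ 0,  3,  7, 12], dtype=torch.int32)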

lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py

Lines changed: 1 addition & 2 deletions
@@ -10,12 +10,11 @@ def apply_rotary_pos_emb(
     sin: Tensor,
     q_embed: Tensor = None,
     k_embed: Tensor = None,
-    cos_sin_ids=None,
     cu_seqlens=None,
 ):
     query_states = query_states.contiguous()
     key_states = key_states.contiguous()
-    query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cos_sin_ids, cu_seqlens)
+    query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cu_seqlens)
 
     if q_embed is None:
         q_embed = query_states

lmdeploy/pytorch/kernels/dlinfer/pagedattention.py

Lines changed: 2 additions & 5 deletions
@@ -25,8 +25,7 @@ def prefill_attention(
     num_kv_heads = value_states.shape[1]
 
     if is_unpaged_prefill:
-        output = torch.empty_like(query_states)
-        ext_ops.prefill_attention(
+        return ext_ops.prefill_attention(
             query_states,
             key_states,
             value_states,
@@ -36,10 +35,8 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
-            attn_output=output,
+            attn_output=attn_output,
         )
-        attn_output.copy_(output)
-        return attn_output
     else:
         return ext_ops.paged_prefill_attention(
             query_states,
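
With this change the unpaged-prefill branch hands the caller-provided attn_output straight to the kernel instead of filling a temporary buffer and copying it back, saving one allocation and one device copy per call. A toy sketch of the two patterns, using a stand-in kernel rather than the real ext_ops.prefill_attention:

import torch


def toy_prefill_kernel(query, attn_output):
    # stand-in for a kernel that writes its result into `attn_output`
    attn_output.copy_(query)
    return attn_output


query = torch.randn(12, 8, 64)
attn_output = torch.empty_like(query)

# old pattern: temporary buffer plus an extra copy
tmp = torch.empty_like(query)
toy_prefill_kernel(query, tmp)
attn_output.copy_(tmp)

# new pattern: write straight into the caller-provided buffer
toy_prefill_kernel(query, attn_output)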

lmdeploy/pytorch/models/internlm2.py

Lines changed: 2 additions & 1 deletion
@@ -75,14 +75,15 @@ def forward(
         query_states, key_states, value_states = self.wqkv.split_qkv(
             qkv_states)
 
+        cu_seqlens = attn_metadata.cu_seqlens
         # apply rotary embedding
         cos, sin = rotary_pos_emb
         query_states, key_states = self.apply_rotary_pos_emb(
             query_states,
             key_states,
             cos,
             sin,
-            attn_metadata,
+            cu_seqlens,
             inplace=True,
         )

lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,6 @@
 from torch import Tensor, nn
 
 from ..backends import OpType, get_backend
-from ..backends.attention import AttentionMetadata
 from ..backends.rotary_embedding import (Llama3Parameters,
                                          LongRoPEScalingParameters, RopeType,
                                          YarnParameters)
@@ -44,7 +43,7 @@ def forward(self,
                 key: Tensor,
                 cos: Tensor,
                 sin: Tensor,
-                attn_metadata: AttentionMetadata,
+                cu_seqlens: Tensor,
                 inplace: bool = True):
         """forward."""
-        return self.impl.forward(query, key, cos, sin, attn_metadata, inplace)
+        return self.impl.forward(query, key, cos, sin, cu_seqlens, inplace)
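
Putting the model and nn hunks together, the caller-side convention after this commit looks like the sketch below (assembled from the internlm2.py hunk above, not new code in this commit): models read cu_seqlens from the attention metadata and pass it down in place of the whole attn_metadata object.

# inside a model's attention forward (see internlm2.py above)
cu_seqlens = attn_metadata.cu_seqlens
cos, sin = rotary_pos_emb
query_states, key_states = self.apply_rotary_pos_emb(
    query_states,
    key_states,
    cos,
    sin,
    cu_seqlens,   # the positional index lookup (cos_sin_ids) is no longer passed
    inplace=True,
)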
