
Commit 0c08f88

refactor(camb)

1 parent a7a044b

7 files changed: +10 -25 lines

lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py (+1 -2)

@@ -14,7 +14,6 @@ def forward(self,
                 key: Tensor,
                 cos: Tensor,
                 sin: Tensor,
-                cu_seqlens: Tensor,
                 inplace: bool = True):
         """forward."""
         if inplace:
@@ -23,7 +22,7 @@ def forward(self,
         else:
             q_embed = torch.empty_like(query)
             k_embed = torch.empty_like(key)
-        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cu_seqlens)
+        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)
 
 
 class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
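
Note: a minimal sketch of the call path after this change, with a pure-PyTorch stand-in for the dlinfer kernel (the real op is apply_rotary_pos_emb from lmdeploy.pytorch.kernels.dlinfer). The stand-in, tensor layout, and sample shapes are assumptions for illustration only; the point is that cu_seqlens no longer threads through the rotary-embedding forward.

import torch
from torch import Tensor


def apply_rotary_pos_emb_stub(query, key, cos, sin, q_embed, k_embed):
    # Stand-in for the dlinfer kernel: mirrors the new 6-argument call
    # (no cu_seqlens) and returns the output buffers unchanged.
    if q_embed is not query:
        q_embed.copy_(query)
    if k_embed is not key:
        k_embed.copy_(key)
    return q_embed, k_embed


def forward_sketch(query: Tensor, key: Tensor, cos: Tensor, sin: Tensor,
                   inplace: bool = True):
    """Mirrors the dlinfer ApplyRotaryEmb implementation's forward after this commit."""
    if inplace:
        q_embed, k_embed = query, key      # rotate into the input buffers
    else:
        q_embed = torch.empty_like(query)  # fresh output buffers
        k_embed = torch.empty_like(key)
    return apply_rotary_pos_emb_stub(query, key, cos, sin, q_embed, k_embed)


# assumed layout: [num_tokens, num_heads, head_dim]
q = torch.randn(6, 8, 64)
k = torch.randn(6, 2, 64)
cos = torch.randn(6, 1, 64)
sin = torch.randn(6, 1, 64)
q_out, k_out = forward_sketch(q, k, cos, sin, inplace=False)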

lmdeploy/pytorch/backends/dlinfer/attention.py (+5 -1)

@@ -17,6 +17,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     max_kv_seq_len: int = 1
     cu_seqlens: Optional[Tensor] = None
     is_flash_attn_support_inplace: bool = True
+    is_mock_q_start_loc: bool = False
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
     """dlinfer attention implementation."""
@@ -76,6 +77,7 @@ def forward(
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
         cu_seqlens = attn_metadata.cu_seqlens
+        is_mock_q_start_loc = attn_metadata.is_mock_q_start_loc
 
         # fill kv cache
         k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
@@ -85,6 +87,9 @@ def forward(
         inplace = inplace if attn_metadata.is_flash_attn_support_inplace \
             else False
 
+        if is_mock_q_start_loc:
+            q_start_loc = cu_seqlens
+
         if inplace:
             attn_output = query[..., :self.v_head_size]
         else:
@@ -107,7 +112,6 @@ def forward(
             max_kv_seq_len=max_kv_seq_len,
             is_decoding=is_decoding,
             block_size=block_size,
-            cu_seqlens=cu_seqlens,
             attn_mask=attn_mask,
             is_unpaged_prefill=is_unpaged_prefill,
         )
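
Note: what the new flag does in the forward path, sketched with stand-ins. When a backend sets is_mock_q_start_loc, the cumulative sequence lengths are handed downstream in place of q_start_loc. The dataclass and helper below are illustrative, not the real DlinferAttentionImpl, and the assumption that cu_seqlens is the zero-prefixed cumulative query length is stated in the comments.

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor


@dataclass
class AttnMetadataSketch:
    # only the fields this commit touches
    cu_seqlens: Optional[Tensor] = None
    is_flash_attn_support_inplace: bool = True
    is_mock_q_start_loc: bool = False


def pick_q_start_loc(q_start_loc: Tensor, attn_metadata: AttnMetadataSketch) -> Tensor:
    """Mirrors the new `if is_mock_q_start_loc: q_start_loc = cu_seqlens` branch."""
    if attn_metadata.is_mock_q_start_loc:
        return attn_metadata.cu_seqlens
    return q_start_loc


# Two requests with 3 and 2 query tokens (values are illustrative).
q_seqlens = torch.tensor([3, 2])
q_start_loc = torch.tensor([0, 3])  # per-request start offsets
# Assumption: cu_seqlens is the zero-prefixed cumulative query length.
cu_seqlens = torch.cat([q_seqlens.new_zeros(1), q_seqlens.cumsum(0)])  # [0, 3, 5]

meta = AttnMetadataSketch(cu_seqlens=cu_seqlens, is_mock_q_start_loc=True)
print(pick_q_start_loc(q_start_loc, meta))  # tensor([0, 3, 5]) goes downstream as q_start_loc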

lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py (+1)

@@ -99,6 +99,7 @@ def update_step_context(cls, step_context):
             max_kv_seq_len=max_kv_seq_len,
             cu_seqlens=cu_seqlens,
             is_flash_attn_support_inplace=False,
+            is_mock_q_start_loc=True,
         )
 
         step_context.attn_metadata = attn_metadata
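
Note: the two camb-specific switches taken together: is_flash_attn_support_inplace=False makes the attention impl fall back to a separate output buffer instead of writing into query[..., :v_head_size], and is_mock_q_start_loc=True makes it pass cu_seqlens as q_start_loc. The sketch below shows the buffer choice; the diff only shows the inplace branch, so the non-inplace allocation here is an assumption.

import torch
from torch import Tensor


def choose_attn_output(query: Tensor, v_head_size: int,
                       is_flash_attn_support_inplace: bool) -> Tensor:
    """Illustrates the inplace/non-inplace branch in the dlinfer attention forward."""
    if is_flash_attn_support_inplace:
        # reuse the leading v_head_size channels of the query buffer
        return query[..., :v_head_size]
    # camb path: flash attention cannot write in place; assumed fresh allocation
    return query.new_empty(*query.shape[:-1], v_head_size)


q = torch.randn(5, 8, 128)
out = choose_attn_output(q, v_head_size=128, is_flash_attn_support_inplace=False)
print(out.shape)  # torch.Size([5, 8, 128])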

lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py (+1 -2)

@@ -10,11 +10,10 @@ def apply_rotary_pos_emb(
     sin: Tensor,
     q_embed: Tensor = None,
     k_embed: Tensor = None,
-    cu_seqlens=None,
 ):
     query_states = query_states.contiguous()
     key_states = key_states.contiguous()
-    query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cu_seqlens)
+    query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
     if q_embed is None:
         q_embed = query_states
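
Note: for readers unfamiliar with the op behind ext_ops.apply_rotary_pos_emb, the usual rotate-half formulation of rotary position embedding is sketched below in plain PyTorch. The layout and half-rotation convention of the dlinfer extension are assumptions; this only illustrates the math, not the vendor kernel.

import torch
from torch import Tensor


def rotate_half(x: Tensor) -> Tensor:
    """Split the last dimension in half and map (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_reference(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor):
    """Classic RoPE: q' = q*cos + rotate_half(q)*sin, and likewise for k."""
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


# assumed layout: [num_tokens, num_heads, head_dim], cos/sin broadcast over heads
q = torch.randn(5, 8, 64)
k = torch.randn(5, 2, 64)
cos = torch.randn(5, 1, 64)
sin = torch.randn(5, 1, 64)
q_rot, k_rot = rotary_reference(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape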

lmdeploy/pytorch/kernels/dlinfer/pagedattention.py (+1 -16)

@@ -17,23 +17,17 @@ def prefill_attention(
     kv_seq_len: Tensor,
     max_q_seq_len: int,
     block_size: int,
-    cu_seqlens: Tensor,
     attn_mask: Sequence[Optional[Tensor]],
     is_unpaged_prefill: Optional[bool],
 ):
-    num_q_heads = query_states.shape[1]
-    num_kv_heads = value_states.shape[1]
-
     if is_unpaged_prefill:
         return ext_ops.prefill_attention(
             query_states,
             key_states,
             value_states,
-            cu_seqlens,
+            q_start_loc,
             q_seq_len,
             max_q_seq_len,
-            num_q_heads,
-            num_kv_heads,
             attn_mask,
             attn_output=attn_output,
         )
@@ -56,11 +50,6 @@ def prefill_attention(
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           max_kv_seq_len, block_offsets, block_size):
-    num_q_heads = q.shape[1]
-    num_kv_heads = k_cache.shape[1]
-    q = q.unsqueeze(1)
-    attn_output = attn_output.unsqueeze(1)
-
     return ext_ops.paged_decode_attention(
         q,
         k_cache,
@@ -69,8 +58,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_size,
         kv_seq_len,
         max_kv_seq_len,
-        num_q_heads,
-        num_kv_heads,
         attn_output=attn_output,
     )
 
@@ -90,7 +77,6 @@ def paged_attention_fwd(
     max_kv_seq_len: int,
     is_decoding: bool,
     block_size: int,
-    cu_seqlens: Tensor,
     attn_mask: Sequence[Optional[Tensor]] = (),
     is_unpaged_prefill: Optional[bool] = None,
 ):
@@ -108,7 +94,6 @@ def paged_attention_fwd(
         kv_seqlens,
         max_q_seq_len,
         block_size,
-        cu_seqlens,
         attn_mask,
         is_unpaged_prefill,
     )
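
Note: the removed lines computed num_q_heads and num_kv_heads from tensor shapes (query_states.shape[1] and value_states.shape[1] for prefill, k_cache.shape[1] for decode) and forwarded them to ext_ops; dropping them suggests the updated dlinfer op signatures recover head counts from the tensors themselves. A minimal sketch of that recovery, under the layout implied by the removed code (an assumption about the actual vendor kernels):

import torch
from torch import Tensor


def infer_head_counts(query_states: Tensor, value_states: Tensor):
    """Recompute what the removed prefill code passed explicitly:
    num_q_heads from the query tensor, num_kv_heads from the value tensor."""
    num_q_heads = query_states.shape[1]
    num_kv_heads = value_states.shape[1]
    return num_q_heads, num_kv_heads


# Assumed unpaged-prefill layout: [num_tokens, num_heads, head_dim].
q = torch.randn(10, 32, 128)
v = torch.randn(10, 8, 128)
print(infer_head_counts(q, v))  # (32, 8) -- derivable inside the op, so no extra arguments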

lmdeploy/pytorch/models/internlm2.py (-2)

@@ -75,15 +75,13 @@ def forward(
         query_states, key_states, value_states = self.wqkv.split_qkv(
             qkv_states)
 
-        cu_seqlens = attn_metadata.cu_seqlens
         # apply rotary embedding
         cos, sin = rotary_pos_emb
         query_states, key_states = self.apply_rotary_pos_emb(
             query_states,
             key_states,
             cos,
             sin,
-            cu_seqlens,
             inplace=True,
         )
 
lmdeploy/pytorch/nn/rotary_embedding.py (+1 -2)

@@ -43,7 +43,6 @@ def forward(self,
                 key: Tensor,
                 cos: Tensor,
                 sin: Tensor,
-                cu_seqlens: Tensor,
                 inplace: bool = True):
         """forward."""
-        return self.impl.forward(query, key, cos, sin, cu_seqlens, inplace)
+        return self.impl.forward(query, key, cos, sin, inplace)
