
Commit 81aacd1

add attenMask in paged_attention.
1 parent af6dbbe commit 81aacd1

3 files changed: +89 -110 lines changed

csrc/diopi_helper.h

Lines changed: 1 addition & 6 deletions
@@ -35,12 +35,7 @@ struct IsOptionalArithmetic<c10::optional<T>> : std::is_arithmetic<T> {};
 
 }  // namespace type_traits
 
-inline void checkTensorOnDevice(const at::Tensor& tensor) {
-  //if (tensor.device().type() == at::DeviceType::CPU) {
-  //  DIPU_LOGE("This op only runs on Device");
-  //  throw std::runtime_error("This op only runs on Device");
-  //}
-}
+inline void checkTensorOnDevice(const at::Tensor& tensor) {}
 
 inline void checkTensorOnDevice(const c10::optional<at::Tensor>& tensor) {
   if (tensor) {

csrc/extensions.cpp

Lines changed: 32 additions & 70 deletions
@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
             b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
                              const at::IntArrayRef& actual_seq_lengths,
-                             int64_t max_input_len, int64_t num_heads,
+                             int64_t max_input_len, int64_t num_heads,
                              int64_t num_key_value_heads, int64_t dim) {
   callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
-            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
+            dim);
 }
 
 void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
-                       const at::Tensor& frequency_penalty,
-                       const at::Tensor& repetition_penalty,
-                       const at::Tensor& p_token_ids,
-                       const at::Tensor& p_token_counts) {
-  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
-            p_token_ids, p_token_counts);
+                       const at::Tensor& frequency_penalty,
+                       const at::Tensor& repetition_penalty,
+                       const at::Tensor& p_token_ids,
+                       const at::Tensor& p_token_counts) {
+  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
+            repetition_penalty, p_token_ids, p_token_counts);
 }
 
-void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
-            numHeads, numKeyValueHeads, dim,
-            block_table, block_size);
+void extPagedAttention(at::Tensor& out, const at::Tensor& q,
+                       const at::Tensor& k, const at::Tensor& v,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1,
+                       int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {},
+                       int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim, block_table, block_size);
 }
 
-void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
+void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key,
+                          const at::Tensor& cos, const at::Tensor& sin,
+                          int64_t dim) {
   callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
 }
 
 void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
-                        const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
+                        const at::Tensor& x2,
+                        const c10::optional<at::Tensor>& bias,
                         const char* group, const char* reduce_op,
                         int64_t comm_turn, int64_t stream_mode) {
-  callDiopi(diopiMatmulAllReduce, out, x1, x2,
-            bias, group, reduce_op, comm_turn, stream_mode);
+  callDiopi(diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
+            comm_turn, stream_mode);
 }
 
 // 判断是否有对应的 diopi 实现：
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           "deeplink ext_paged_attention");
   }
   if (&diopiRotaryEmbeddingV2 != nullptr) {
-    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2");
+    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
+          "deeplink extRotaryEmbeddingV2");
  }
   if (&diopiMatmulAllReduce != nullptr) {
     m.def("matmul_all_reduce", &extMatmulAllReduce,
-          "deeplink ext_matmul_all_reduce",
-          py::arg("out"), py::arg("x1"),
-          py::arg("x2"), py::arg("bias"),
-          py::arg("group"), py::arg("reduce_op") = "sum",
-          py::arg("comm_turn") = 0, py::arg("stream_mode") = 1);
+          "deeplink ext_matmul_all_reduce", py::arg("out"), py::arg("x1"),
+          py::arg("x2"), py::arg("bias"), py::arg("group"),
+          py::arg("reduce_op") = "sum", py::arg("comm_turn") = 0,
+          py::arg("stream_mode") = 1);
   }
 }
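
A minimal usage sketch for the rebound "paged_attention" binding after this change: the new atten_mask argument sits in fifth position, and passing None keeps the previous mask-free behaviour. The import path, tensor shapes, and block-table contents below are illustrative assumptions, not part of the commit:

    import torch
    import deeplink_ext.cpp_extensions as ext  # assumed import path for the built module

    num_heads, num_kv_heads, head_dim, block_size = 32, 32, 128, 64
    q = torch.randn(2, num_heads * head_dim).cuda()      # one query token per sequence (made-up shape)
    k_cache = torch.randn(8, block_size, num_kv_heads * head_dim).cuda()  # assumed paged KV-cache layout
    v_cache = torch.randn(8, block_size, num_kv_heads * head_dim).cuda()
    out = torch.empty_like(q)
    b_seq_len = [40, 100]                                 # actual_seq_lengths per sequence
    block_table = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32).cuda()
    atten_mask = None                                     # new optional argument added by this commit

    # Positional order mirrors the new extPagedAttention signature:
    # (out, q, k, v, atten_mask, actual_seq_lengths, numHeads, numKeyValueHeads, dim, block_table, block_size)
    ext.paged_attention(out, q, k_cache, v_cache, atten_mask, b_seq_len,
                        num_heads, num_kv_heads, head_dim, block_table, block_size)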

deeplink_ext/patch_lightllm.py

Lines changed: 56 additions & 34 deletions
@@ -46,7 +46,9 @@ def patch_apply_penalty():
         apply_penalty_pack.apply_penalty_v2 = ext.apply_penalty_v2
 
     def patch_context_attention_inference():
-        def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
+        def flash_context_attention(
+            q, k, v, out, b_start_loc, b_seq_len, max_input_len
+        ):
             batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
             numKeyValueHeads = k.shape[1]
             assert k.shape[1] == v.shape[1]
@@ -62,52 +64,72 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
 
                 single_out = out[start:end, :].view(1, single_seq_len, -1)
                 if single_seq_len not in mask_cache:
-                    mask = torch.tril(torch.ones(single_seq_len, single_seq_len, dtype=torch.bool), diagonal=0).cuda()
+                    mask = torch.tril(
+                        torch.ones(
+                            single_seq_len, single_seq_len, dtype=torch.bool
+                        ),
+                        diagonal=0,
+                    ).cuda()
                     mask = mask.repeat(1, 1, 1)
                     mask = torch.logical_not(mask)
                     mask_cache[single_seq_len] = mask
-                    print(f"cache mask in context attention, seqLen:{single_seq_len}")
+                    print(
+                        f"cache mask in context attention, seqLen:{single_seq_len}"
+                    )
                 mask = mask_cache[single_seq_len]
-                ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
+                ext.prompt_flash_attention(
+                    single_out,
+                    single_q,
+                    single_k,
+                    single_v,
+                    None,
+                    mask,
+                    [],
+                    head,
+                    scale,
+                    2147473647,
+                    0,
+                    "BSH",
+                    numKeyValueHeads,
+                )
             return out
 
-        # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        #     batch = b_start_loc.shape[0]
-        #     scale = 1 / math.sqrt(dim)
-        #     mask_key_str = str(batch) + ":" + str(max_input_len)
-        #     if mask_key_str not in mask_cache:
-        #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-        #         mask = mask.repeat(batch, 1, 1)
-        #         mask = torch.logical_not(mask)
-        #         mask_cache[mask_key_str] = mask
-        #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
-
-        #     mask = mask_cache[mask_key_str]
-        #     ext.prompt_flash_attention(out, q, k, v,
-        #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
-        #     return out
-
-        # context_attention_pack.context_attention_fwd = (
-        #     # flash_context_attention
-        #     fused_context_attention
-        # )
         context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
     def patch_paged_token_attention_inference():
-        # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
-        #                         b_seq_len, block_table, q_head_num, kv_head_num,
-        #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-        #                         None, None, None, None, None, None, None, None
-        #     )
-        #     return out
-
-        token_attention_pack.paged_token_attention = ext.paged_attention
+        def paged_token_attention(
+            out,
+            q,
+            k_cache,
+            v_cache,
+            b_seq_len,
+            q_head_num,
+            kv_head_num,
+            head_dim,
+            block_table,
+            block_size,
+        ):
+            ext.paged_attention(
+                out,
+                q,
+                k_cache,
+                v_cache,
+                None,
+                b_seq_len,
+                q_head_num,
+                kv_head_num,
+                head_dim,
+                block_table,
+                block_size,
+            )
 
+        token_attention_pack.paged_token_attention = paged_token_attention
 
     def patch_token_attention_inference():
         token_attention_pack.token_att_fwd = ext.token_attention_inference
-        token_attention_pack.token_decode_attention_fwd = ext.token_decode_attention_inference_batch_one#ext.token_decode_attention_inference
+        token_attention_pack.token_decode_attention_fwd = (
+            ext.token_decode_attention_inference_batch_one
+        ) # ext.token_decode_attention_inference
 
     def patch_token_softmax_reducev_inference():
         token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = (
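
For reference, a small sketch of the causal mask that flash_context_attention builds and caches per sequence length in the hunk above. The original code moves the mask to the device with .cuda(); it is shown on CPU here with a made-up sequence length, and True marks the positions to mask out (future tokens):

    import torch

    single_seq_len = 4
    mask = torch.tril(
        torch.ones(single_seq_len, single_seq_len, dtype=torch.bool), diagonal=0
    )
    mask = mask.repeat(1, 1, 1)      # adds a leading dim -> shape (1, 4, 4)
    mask = torch.logical_not(mask)   # True above the diagonal = positions to mask
    print(mask)
    # tensor([[[False,  True,  True,  True],
    #          [False, False,  True,  True],
    #          [False, False, False,  True],
    #          [False, False, False, False]]])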
