Commit af6dbbe

fix
1 parent f45114f commit af6dbbe

2 files changed: +40 -56 lines changed

csrc/extensions.cpp

Lines changed: 13 additions & 29 deletions

@@ -375,17 +375,12 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
 
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
-                             const c10::optional<at::Tensor>& padding_mask = {},
-                             const c10::optional<at::Tensor>& atten_mask = {},
-                             const at::IntArrayRef& actual_seq_lengths = {},
-                             int64_t num_heads = 1, double scale_value = 1.0,
-                             int64_t pre_tokens = 2147473647,
-                             int64_t next_tokens = 0,
-                             const std::string& input_layout = "BSH",
-                             int64_t num_key_value_heads = 0) {
-  callDiopi(diopiPromptFlashAttention, out, q, k, v, padding_mask, atten_mask,
-            actual_seq_lengths, num_heads, scale_value, pre_tokens,
-            next_tokens, input_layout.c_str(), num_key_value_heads);
+                             const at::Tensor& atten_mask,
+                             const at::IntArrayRef& actual_seq_lengths,
+                             int64_t max_input_len, int64_t num_heads,
+                             int64_t num_key_value_heads, int64_t dim) {
+  callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
 }
 
 void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -417,24 +412,13 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const c10::optional<at::Tensor>& padding_mask = {},
-                       const c10::optional<at::Tensor>& atten_mask = {},
-                       const at::IntArrayRef& actual_seq_lengths = {},
-                       const c10::optional<at::Tensor>& block_table = {},
-                       int64_t num_heads = 1, int64_t num_key_value_heads = 0,
-                       double scale_value = 1.0, const std::string& input_layout = "BSH",
-                       int64_t block_size = 0, int64_t inner_precise = 1,
-                       const c10::optional<at::Tensor>& antiquant_scale = {}, const c10::optional<at::Tensor>& antiquant_offset = {},
-                       const c10::optional<at::Tensor>& dequant_scale1 = {}, const c10::optional<at::Tensor>& quant_scale1 = {},
-                       const c10::optional<at::Tensor>& dequant_scale2 = {}, const c10::optional<at::Tensor>& quant_scale2 = {},
-                       const c10::optional<at::Tensor>& quant_offset2 = {}, const c10::optional<at::Tensor>& kv_padding_size = {}
-                       ) {
-  callDiopi(diopiPagedAttention, out, q, k, v, padding_mask, atten_mask, actual_seq_lengths,
-            antiquant_scale, antiquant_offset,
-            block_table,
-            dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, kv_padding_size,
-            num_heads, scale_value, input_layout.c_str(), num_key_value_heads, block_size, inner_precise
-            );
+                       const at::IntArrayRef& actual_seq_lengths,
+                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
+                       const at::Tensor& block_table,
+                       int64_t block_size) {
+  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim,
+            block_table, block_size);
 }
 
 void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
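
Both entry points now take a fixed positional argument list instead of the long optional-parameter signatures. The sketch below shows the resulting call convention from the Python side, assuming the bindings are exposed as ext.prompt_flash_attention and ext.paged_attention (as patch_lightllm.py below uses them); the import path, mask construction, and tensor shapes are illustrative assumptions, not part of this commit.

import torch
import deeplink_ext.cpp_extensions as ext  # assumed import path for the bindings

def prompt_flash_attention_example(q, k, v, b_seq_len, max_input_len,
                                   num_heads, num_kv_heads, head_dim):
    # Argument order mirrors the new extPromptFlashAttention signature.
    out = torch.empty_like(q)
    batch = len(b_seq_len)
    # Inverted causal mask, following the (now commented-out) mask_cache logic
    # in patch_lightllm.py.
    mask = torch.tril(torch.ones(max_input_len, max_input_len,
                                 dtype=torch.bool, device=q.device))
    mask = torch.logical_not(mask).repeat(batch, 1, 1)
    # actual_seq_lengths maps to at::IntArrayRef, so b_seq_len is a list of ints here.
    ext.prompt_flash_attention(out, q, k, v, mask, b_seq_len, max_input_len,
                               num_heads, num_kv_heads, head_dim)
    return out

def paged_attention_example(q, k_cache, v_cache, b_seq_len, block_table,
                            num_heads, num_kv_heads, head_dim, block_size):
    # Argument order mirrors the new extPagedAttention signature.
    out = torch.empty_like(q)
    ext.paged_attention(out, q, k_cache, v_cache, b_seq_len,
                        num_heads, num_kv_heads, head_dim,
                        block_table, block_size)
    return out

Note that scale_value, input_layout, and the quantization tensors are gone from both signatures; with dim passed explicitly, the softmax scale is presumably derived inside the DIOPI kernels.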

deeplink_ext/patch_lightllm.py

Lines changed: 27 additions & 27 deletions

@@ -71,38 +71,38 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
 
-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        batch = b_start_loc.shape[0]
-        scale = 1 / math.sqrt(dim)
-        mask_key_str = str(batch) + ":" + str(max_input_len)
-        if mask_key_str not in mask_cache:
-            mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-            mask = mask.repeat(batch, 1, 1)
-            mask = torch.logical_not(mask)
-            mask_cache[mask_key_str] = mask
-            print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
+    # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+    #     batch = b_start_loc.shape[0]
+    #     scale = 1 / math.sqrt(dim)
+    #     mask_key_str = str(batch) + ":" + str(max_input_len)
+    #     if mask_key_str not in mask_cache:
+    #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
+    #         mask = mask.repeat(batch, 1, 1)
+    #         mask = torch.logical_not(mask)
+    #         mask_cache[mask_key_str] = mask
+    #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
 
-        mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(out, q, k, v,
-                                   None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
-        return out
-
-    context_attention_pack.context_attention_fwd = (
-        # flash_context_attention
-        fused_context_attention
-    )
+    #     mask = mask_cache[mask_key_str]
+    #     ext.prompt_flash_attention(out, q, k, v,
+    #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
+    #     return out
 
+    # context_attention_pack.context_attention_fwd = (
+    #     # flash_context_attention
+    #     fused_context_attention
+    # )
+    context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
 def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        ext.paged_attention(out, q, k_cache, v_cache, None, None,
-                            b_seq_len, block_table, q_head_num, kv_head_num,
-                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-                            None, None, None, None, None, None, None, None
-                            )
-        return out
+    # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+    #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
+    #                         b_seq_len, block_table, q_head_num, kv_head_num,
+    #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
+    #                         None, None, None, None, None, None, None, None
+    #                         )
+    #     return out
 
-    token_attention_pack.paged_token_attention = (paged_token_attention)
+    token_attention_pack.paged_token_attention = ext.paged_attention
 
 
 def patch_token_attention_inference():
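
The net effect of the Python-side change is that lightllm's hooks are bound directly to the extension functions, so the old wrappers (mask cache, 1 / math.sqrt(dim) scaling, the "BSH" layout string, and the trailing None placeholders) are no longer in the call path. A minimal sketch of that rebinding pattern, using stand-in namespaces where patch_lightllm.py resolves the real lightllm modules:

import types
import deeplink_ext.cpp_extensions as ext  # assumed import path for the bindings

# Stand-ins for the lightllm modules that patch_lightllm.py looks up elsewhere;
# the real objects are whatever context_attention_pack / token_attention_pack
# refer to in that file.
context_attention_pack = types.SimpleNamespace()
token_attention_pack = types.SimpleNamespace()

# After this commit the hooks point straight at the DIOPI-backed kernels, so any
# mask construction or scaling has to happen in the kernel or at the call site.
context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
token_attention_pack.paged_token_attention = ext.paged_attention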
