@@ -341,38 +341,6 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
                                      b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
@@ -412,11 +380,11 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1, int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {}, int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
             numHeads, numKeyValueHeads, dim,
             block_table, block_size);
 }
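
Note: the pybind11 registration of extPagedAttention falls outside this hunk. Below is a minimal sketch of how it could be exposed, assuming the same null-symbol guard pattern used for the other ops in PYBIND11_MODULE; the Python-facing name "paged_attention" is an assumption, not taken from this diff:

  // Sketch only: register the op when the DIOPI symbol is present, mirroring
  // the guard used for prompt_flash_attention. Note that pybind11 does not
  // pick up the C++ default arguments, so Python callers would pass every
  // argument unless py::arg(...) defaults were added to this m.def call.
  if (&diopiPagedAttention != nullptr) {
    m.def("paged_attention", &extPagedAttention,
          "deeplink ext_paged_attention");
  }
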
@@ -501,18 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");