@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
             b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
                              const at::IntArrayRef& actual_seq_lengths,
-                             int64_t max_input_len, int64_t num_heads,
+                             int64_t max_input_len, int64_t num_heads,
                              int64_t num_key_value_heads, int64_t dim) {
   callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
-            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
+            dim);
 }
 
 void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
-                       const at::Tensor& frequency_penalty,
-                       const at::Tensor& repetition_penalty,
-                       const at::Tensor& p_token_ids,
-                       const at::Tensor& p_token_counts) {
-  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
-            p_token_ids, p_token_counts);
+                       const at::Tensor& frequency_penalty,
+                       const at::Tensor& repetition_penalty,
+                       const at::Tensor& p_token_ids,
+                       const at::Tensor& p_token_counts) {
+  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
+            repetition_penalty, p_token_ids, p_token_counts);
 }
 
-void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
-            numHeads, numKeyValueHeads, dim,
-            block_table, block_size);
+void extPagedAttention(at::Tensor& out, const at::Tensor& q,
+                       const at::Tensor& k, const at::Tensor& v,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1,
+                       int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {},
+                       int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim, block_table, block_size);
 }
 
-void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
+void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key,
+                          const at::Tensor& cos, const at::Tensor& sin,
+                          int64_t dim) {
   callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
 }
 
 void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
-                        const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
+                        const at::Tensor& x2,
+                        const c10::optional<at::Tensor>& bias,
                         const char* group, const char* reduce_op,
                         int64_t comm_turn, int64_t stream_mode) {
-  callDiopi(diopiMatmulAllReduce, out, x1, x2,
-            bias, group, reduce_op, comm_turn, stream_mode);
+  callDiopi(diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
+            comm_turn, stream_mode);
 }
 
 // Check whether a corresponding diopi implementation exists:
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           "deeplink ext_paged_attention");
   }
   if (&diopiRotaryEmbeddingV2 != nullptr) {
-    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2");
+    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
+          "deeplink extRotaryEmbeddingV2");
   }
   if (&diopiMatmulAllReduce != nullptr) {
     m.def("matmul_all_reduce", &extMatmulAllReduce,
-          "deeplink ext_matmul_all_reduce",
-          py::arg("out"), py::arg("x1"),
-          py::arg("x2"), py::arg("bias"),
-          py::arg("group"), py::arg("reduce_op") = "sum",
-          py::arg("comm_turn") = 0, py::arg("stream_mode") = 1);
+          "deeplink ext_matmul_all_reduce", py::arg("out"), py::arg("x1"),
+          py::arg("x2"), py::arg("bias"), py::arg("group"),
+          py::arg("reduce_op") = "sum", py::arg("comm_turn") = 0,
+          py::arg("stream_mode") = 1);
   }
 }
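Below is a minimal calling sketch, not part of the diff itself: it assumes the file above is compiled against ATen and a working DIOPI backend, and every tensor shape and constant is an illustrative placeholder. It only demonstrates that, after this change, `extPagedAttention` takes an optional `atten_mask` and defaulted trailing parameters, so the mask can simply be omitted.

```cpp
// Illustrative caller for the new extPagedAttention signature (sketch only).
// Assumes the extension source above is linked in and a DIOPI backend exists;
// all shapes and values below are made-up placeholders.
#include <vector>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Forward declaration copied from the signature introduced by this diff
// (defaults live on the definition, so they are not repeated here).
void extPagedAttention(at::Tensor& out, const at::Tensor& q,
                       const at::Tensor& k, const at::Tensor& v,
                       const c10::optional<at::Tensor>& atten_mask,
                       const at::IntArrayRef& actual_seq_lengths,
                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
                       const c10::optional<at::Tensor>& block_table,
                       int64_t block_size);

int main() {
  // Hypothetical decode-step layout: batch = 2, 8 heads, head_dim = 128.
  auto q = at::randn({2, 8, 128});
  auto k = at::randn({64, 8, 128});  // paged KV cache, placeholder layout
  auto v = at::randn({64, 8, 128});
  auto out = at::empty_like(q);
  // Per-sequence block indices into the paged cache; values are dummies.
  auto block_table = at::zeros({2, 4}, at::kInt);
  std::vector<int64_t> seq_lens = {10, 7};

  // atten_mask is optional after this change, so pass c10::nullopt to skip it.
  extPagedAttention(out, q, k, v, /*atten_mask=*/c10::nullopt,
                    at::IntArrayRef(seq_lens), /*numHeads=*/8,
                    /*numKeyValueHeads=*/8, /*dim=*/128, block_table,
                    /*block_size=*/16);
  return 0;
}
```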