@@ -541,8 +541,11 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
541
541
542
542
/* argsort methods for 32-bit and 64-bit dtypes */
543
543
template <typename T>
544
- X86_SIMD_SORT_INLINE void
545
- avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
544
+ X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
545
+ arrsize_t *arg,
546
+ arrsize_t arrsize,
547
+ bool hasnan = false ,
548
+ bool descending = false )
546
549
{
547
550
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548
551
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
@@ -558,29 +561,37 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
558
561
if constexpr (std::is_floating_point_v<T>) {
559
562
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560
563
std_argsort_withnan (arr, arg, 0 , arrsize);
564
+
565
+ if (descending) { std::reverse (arg, arg + arrsize); }
566
+
561
567
return ;
562
568
}
563
569
}
564
570
UNUSED (hasnan);
565
571
argsort_64bit_<vectype, argtype>(
566
572
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
573
+
574
+ if (descending) { std::reverse (arg, arg + arrsize); }
567
575
}
568
576
}
569
577
570
578
template <typename T>
571
- X86_SIMD_SORT_INLINE std::vector<arrsize_t >
572
- avx512_argsort ( T *arr, arrsize_t arrsize, bool hasnan = false )
579
+ X86_SIMD_SORT_INLINE std::vector<arrsize_t > avx512_argsort (
580
+ T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
573
581
{
574
582
std::vector<arrsize_t > indices (arrsize);
575
583
std::iota (indices.begin (), indices.end (), 0 );
576
- avx512_argsort<T>(arr, indices.data (), arrsize, hasnan);
584
+ avx512_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
577
585
return indices;
578
586
}
579
587
580
588
/* argsort methods for 32-bit and 64-bit dtypes */
581
589
template <typename T>
582
- X86_SIMD_SORT_INLINE void
583
- avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
590
+ X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
591
+ arrsize_t *arg,
592
+ arrsize_t arrsize,
593
+ bool hasnan = false ,
594
+ bool descending = false )
584
595
{
585
596
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586
597
avx2_half_vector<T>,
@@ -594,22 +605,27 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594
605
if constexpr (std::is_floating_point_v<T>) {
595
606
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596
607
std_argsort_withnan (arr, arg, 0 , arrsize);
608
+
609
+ if (descending) { std::reverse (arg, arg + arrsize); }
610
+
597
611
return ;
598
612
}
599
613
}
600
614
UNUSED (hasnan);
601
615
argsort_64bit_<vectype, argtype>(
602
616
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
617
+
618
+ if (descending) { std::reverse (arg, arg + arrsize); }
603
619
}
604
620
}
605
621
606
622
template <typename T>
607
- X86_SIMD_SORT_INLINE std::vector<arrsize_t >
608
- avx2_argsort ( T *arr, arrsize_t arrsize, bool hasnan = false )
623
+ X86_SIMD_SORT_INLINE std::vector<arrsize_t > avx2_argsort (
624
+ T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
609
625
{
610
626
std::vector<arrsize_t > indices (arrsize);
611
627
std::iota (indices.begin (), indices.end (), 0 );
612
- avx2_argsort<T>(arr, indices.data (), arrsize, hasnan);
628
+ avx2_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
613
629
return indices;
614
630
}
615
631
0 commit comments