@@ -542,7 +542,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
542
542
/* argsort methods for 32-bit and 64-bit dtypes */
543
543
template <typename T>
544
544
X86_SIMD_SORT_INLINE void
545
- avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
545
+ avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false , bool descending = false )
546
546
{
547
547
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548
548
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
@@ -558,29 +558,38 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
558
558
if constexpr (std::is_floating_point_v<T>) {
559
559
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560
560
std_argsort_withnan (arr, arg, 0 , arrsize);
561
+
562
+ if (descending){
563
+ std::reverse (arg, arg + arrsize);
564
+ }
565
+
561
566
return ;
562
567
}
563
568
}
564
569
UNUSED (hasnan);
565
570
argsort_64bit_<vectype, argtype>(
566
571
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
572
+
573
+ if (descending){
574
+ std::reverse (arg, arg + arrsize);
575
+ }
567
576
}
568
577
}
569
578
570
579
template <typename T>
571
580
X86_SIMD_SORT_INLINE std::vector<arrsize_t >
572
- avx512_argsort (T *arr, arrsize_t arrsize, bool hasnan = false )
581
+ avx512_argsort (T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
573
582
{
574
583
std::vector<arrsize_t > indices (arrsize);
575
584
std::iota (indices.begin (), indices.end (), 0 );
576
- avx512_argsort<T>(arr, indices.data (), arrsize, hasnan);
585
+ avx512_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
577
586
return indices;
578
587
}
579
588
580
589
/* argsort methods for 32-bit and 64-bit dtypes */
581
590
template <typename T>
582
591
X86_SIMD_SORT_INLINE void
583
- avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
592
+ avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false , bool descending = false )
584
593
{
585
594
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586
595
avx2_half_vector<T>,
@@ -594,22 +603,31 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594
603
if constexpr (std::is_floating_point_v<T>) {
595
604
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596
605
std_argsort_withnan (arr, arg, 0 , arrsize);
606
+
607
+ if (descending){
608
+ std::reverse (arg, arg + arrsize);
609
+ }
610
+
597
611
return ;
598
612
}
599
613
}
600
614
UNUSED (hasnan);
601
615
argsort_64bit_<vectype, argtype>(
602
616
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
617
+
618
+ if (descending){
619
+ std::reverse (arg, arg + arrsize);
620
+ }
603
621
}
604
622
}
605
623
606
624
template <typename T>
607
625
X86_SIMD_SORT_INLINE std::vector<arrsize_t >
608
- avx2_argsort (T *arr, arrsize_t arrsize, bool hasnan = false )
626
+ avx2_argsort (T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
609
627
{
610
628
std::vector<arrsize_t > indices (arrsize);
611
629
std::iota (indices.begin (), indices.end (), 0 );
612
- avx2_argsort<T>(arr, indices.data (), arrsize, hasnan);
630
+ avx2_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
613
631
return indices;
614
632
}
615
633
@@ -631,7 +649,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
631
649
ymm_vector<arrsize_t >,
632
650
zmm_vector<arrsize_t >>::type;
633
651
634
- if (arrsize > 1 ) {
652
+ if (arrsize > 1 ) {
635
653
if constexpr (std::is_floating_point_v<T>) {
636
654
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
637
655
std_argselect_withnan (arr, arg, k, 0 , arrsize);
0 commit comments