1
+ use super :: masks:: { ToBitMask , ToBitMaskArray } ;
1
2
use crate :: simd:: {
2
3
cmp:: SimdPartialOrd ,
3
4
intrinsics,
@@ -313,28 +314,39 @@ where
313
314
314
315
#[ must_use]
315
316
#[ inline]
316
- pub fn masked_load_or ( slice : & [ T ] , or : Self ) -> Self {
317
+ pub fn masked_load_or ( slice : & [ T ] , or : Self ) -> Self
318
+ where
319
+ Mask < <T as SimdElement >:: Mask , N > : ToBitMask + ToBitMaskArray ,
320
+ {
317
321
Self :: masked_load_select ( slice, Mask :: splat ( true ) , or)
318
322
}
319
323
320
324
#[ must_use]
321
325
#[ inline]
322
- pub fn masked_load_select ( slice : & [ T ] , enable : Mask < isize , N > , or : Self ) -> Self {
323
- let ptr = slice. as_ptr ( ) ;
324
- let idxs = Simd :: < usize , N > :: from_slice ( & [
325
- 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 ,
326
- 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
327
- 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 , 61 , 62 , 63 , 64 ,
328
- ] ) ;
329
- let enable: Mask < isize , N > = enable & idxs. simd_lt ( Simd :: splat ( slice. len ( ) ) ) ;
330
- unsafe { Self :: masked_load_select_ptr ( ptr, enable, or) }
326
+ pub fn masked_load_select (
327
+ slice : & [ T ] ,
328
+ mut enable : Mask < <T as SimdElement >:: Mask , N > ,
329
+ or : Self ,
330
+ ) -> Self
331
+ where
332
+ Mask < <T as SimdElement >:: Mask , N > : ToBitMask + ToBitMaskArray ,
333
+ {
334
+ enable &= {
335
+ let mask = bzhi_u64 ( u64:: MAX , core:: cmp:: min ( N , slice. len ( ) ) as u32 ) ;
336
+ let mask_bytes: [ u8 ; 8 ] = unsafe { core:: mem:: transmute ( mask) } ;
337
+ let mut in_bounds_arr = Mask :: splat ( true ) . to_bitmask_array ( ) ;
338
+ let len = in_bounds_arr. as_ref ( ) . len ( ) ;
339
+ in_bounds_arr. as_mut ( ) . copy_from_slice ( & mask_bytes[ ..len] ) ;
340
+ Mask :: from_bitmask_array ( in_bounds_arr)
341
+ } ;
342
+ unsafe { Self :: masked_load_select_ptr ( slice. as_ptr ( ) , enable, or) }
331
343
}
332
344
333
345
#[ must_use]
334
346
#[ inline]
335
347
pub unsafe fn masked_load_select_unchecked (
336
348
slice : & [ T ] ,
337
- enable : Mask < isize , N > ,
349
+ enable : Mask < < T as SimdElement > :: Mask , N > ,
338
350
or : Self ,
339
351
) -> Self {
340
352
let ptr = slice. as_ptr ( ) ;
@@ -343,7 +355,11 @@ where
343
355
344
356
#[ must_use]
345
357
#[ inline]
346
- pub unsafe fn masked_load_select_ptr ( ptr : * const T , enable : Mask < isize , N > , or : Self ) -> Self {
358
+ pub unsafe fn masked_load_select_ptr (
359
+ ptr : * const T ,
360
+ enable : Mask < <T as SimdElement >:: Mask , N > ,
361
+ or : Self ,
362
+ ) -> Self {
347
363
unsafe { intrinsics:: simd_masked_load ( or, ptr, enable. to_int ( ) ) }
348
364
}
349
365
@@ -526,25 +542,33 @@ where
526
542
}
527
543
528
544
#[ inline]
529
- pub fn masked_store ( self , slice : & mut [ T ] , enable : Mask < isize , N > ) {
530
- let ptr = slice. as_mut_ptr ( ) ;
531
- let idxs = Simd :: < usize , N > :: from_slice ( & [
532
- 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 ,
533
- 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
534
- 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 , 61 , 62 , 63 , 64 ,
535
- ] ) ;
536
- let enable: Mask < isize , N > = enable & idxs. simd_lt ( Simd :: splat ( slice. len ( ) ) ) ;
537
- unsafe { self . masked_store_ptr ( ptr, enable) }
545
+ pub fn masked_store ( self , slice : & mut [ T ] , mut enable : Mask < <T as SimdElement >:: Mask , N > )
546
+ where
547
+ Mask < <T as SimdElement >:: Mask , N > : ToBitMask + ToBitMaskArray ,
548
+ {
549
+ enable &= {
550
+ let mask = bzhi_u64 ( u64:: MAX , core:: cmp:: min ( N , slice. len ( ) ) as u32 ) ;
551
+ let mask_bytes: [ u8 ; 8 ] = unsafe { core:: mem:: transmute ( mask) } ;
552
+ let mut in_bounds_arr = Mask :: splat ( true ) . to_bitmask_array ( ) ;
553
+ let len = in_bounds_arr. as_ref ( ) . len ( ) ;
554
+ in_bounds_arr. as_mut ( ) . copy_from_slice ( & mask_bytes[ ..len] ) ;
555
+ Mask :: from_bitmask_array ( in_bounds_arr)
556
+ } ;
557
+ unsafe { self . masked_store_ptr ( slice. as_mut_ptr ( ) , enable) }
538
558
}
539
559
540
560
#[ inline]
541
- pub unsafe fn masked_store_unchecked ( self , slice : & mut [ T ] , enable : Mask < isize , N > ) {
561
+ pub unsafe fn masked_store_unchecked (
562
+ self ,
563
+ slice : & mut [ T ] ,
564
+ enable : Mask < <T as SimdElement >:: Mask , N > ,
565
+ ) {
542
566
let ptr = slice. as_mut_ptr ( ) ;
543
567
unsafe { self . masked_store_ptr ( ptr, enable) }
544
568
}
545
569
546
570
#[ inline]
547
- pub unsafe fn masked_store_ptr ( self , ptr : * mut T , enable : Mask < isize , N > ) {
571
+ pub unsafe fn masked_store_ptr ( self , ptr : * mut T , enable : Mask < < T as SimdElement > :: Mask , N > ) {
548
572
unsafe { intrinsics:: simd_masked_store ( self , ptr, enable. to_int ( ) ) }
549
573
}
550
574
@@ -1033,3 +1057,14 @@ where
1033
1057
{
1034
1058
type Mask = isize ;
1035
1059
}
1060
+
1061
// Zeroes the bits of `a` at positions `ix` and above, keeping the low `ix`
// bits — matching the semantics of the `bzhi` instruction on x86 BMI2
// (an index of 64 or more returns `a` unchanged).
// TODO: optimize it further if possible
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
#[inline]
fn bzhi_u64(a: u64, ix: u32) -> u64 {
    if ix > 63 {
        a
    } else {
        // Explicit parentheses: in Rust `-` binds tighter than `&` and `<<`
        // binds looser than `-`, so spelling out the grouping avoids a
        // clippy::precedence warning and reader confusion.
        a & ((1u64 << ix) - 1)
    }
}
0 commit comments