Skip to content

Commit 66a5748

Browse files
committed
Optimize bounds checking
1 parent f182fa7 commit 66a5748

File tree

1 file changed

+58
-23
lines changed

1 file changed

+58
-23
lines changed

crates/core_simd/src/vector.rs

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::masks::{ToBitMask, ToBitMaskArray};
12
use crate::simd::{
23
cmp::SimdPartialOrd,
34
intrinsics,
@@ -313,28 +314,39 @@ where
313314

314315
#[must_use]
315316
#[inline]
316-
pub fn masked_load_or(slice: &[T], or: Self) -> Self {
317+
pub fn masked_load_or(slice: &[T], or: Self) -> Self
318+
where
319+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
320+
{
317321
Self::masked_load_select(slice, Mask::splat(true), or)
318322
}
319323

320324
#[must_use]
321325
#[inline]
322-
pub fn masked_load_select(slice: &[T], enable: Mask<isize, N>, or: Self) -> Self {
323-
let ptr = slice.as_ptr();
324-
let idxs = Simd::<usize, N>::from_slice(&[
325-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
326-
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
327-
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
328-
]);
329-
let enable: Mask<isize, N> = enable & idxs.simd_lt(Simd::splat(slice.len()));
330-
unsafe { Self::masked_load_select_ptr(ptr, enable, or) }
326+
pub fn masked_load_select(
327+
slice: &[T],
328+
mut enable: Mask<<T as SimdElement>::Mask, N>,
329+
or: Self,
330+
) -> Self
331+
where
332+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
333+
{
334+
enable &= {
335+
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
336+
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
337+
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
338+
let len = in_bounds_arr.as_ref().len();
339+
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
340+
Mask::from_bitmask_array(in_bounds_arr)
341+
};
342+
unsafe { Self::masked_load_select_ptr(slice.as_ptr(), enable, or) }
331343
}
332344

333345
#[must_use]
334346
#[inline]
335347
pub unsafe fn masked_load_select_unchecked(
336348
slice: &[T],
337-
enable: Mask<isize, N>,
349+
enable: Mask<<T as SimdElement>::Mask, N>,
338350
or: Self,
339351
) -> Self {
340352
let ptr = slice.as_ptr();
@@ -343,7 +355,11 @@ where
343355

344356
#[must_use]
345357
#[inline]
346-
pub unsafe fn masked_load_select_ptr(ptr: *const T, enable: Mask<isize, N>, or: Self) -> Self {
358+
pub unsafe fn masked_load_select_ptr(
359+
ptr: *const T,
360+
enable: Mask<<T as SimdElement>::Mask, N>,
361+
or: Self,
362+
) -> Self {
347363
unsafe { intrinsics::simd_masked_load(or, ptr, enable.to_int()) }
348364
}
349365

@@ -526,25 +542,33 @@ where
526542
}
527543

528544
#[inline]
529-
pub fn masked_store(self, slice: &mut [T], enable: Mask<isize, N>) {
530-
let ptr = slice.as_mut_ptr();
531-
let idxs = Simd::<usize, N>::from_slice(&[
532-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
533-
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
534-
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
535-
]);
536-
let enable: Mask<isize, N> = enable & idxs.simd_lt(Simd::splat(slice.len()));
537-
unsafe { self.masked_store_ptr(ptr, enable) }
545+
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
546+
where
547+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
548+
{
549+
enable &= {
550+
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
551+
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
552+
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
553+
let len = in_bounds_arr.as_ref().len();
554+
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
555+
Mask::from_bitmask_array(in_bounds_arr)
556+
};
557+
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
538558
}
539559

540560
#[inline]
541-
pub unsafe fn masked_store_unchecked(self, slice: &mut [T], enable: Mask<isize, N>) {
561+
pub unsafe fn masked_store_unchecked(
562+
self,
563+
slice: &mut [T],
564+
enable: Mask<<T as SimdElement>::Mask, N>,
565+
) {
542566
let ptr = slice.as_mut_ptr();
543567
unsafe { self.masked_store_ptr(ptr, enable) }
544568
}
545569

546570
#[inline]
547-
pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<isize, N>) {
571+
pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
548572
unsafe { intrinsics::simd_masked_store(self, ptr, enable.to_int()) }
549573
}
550574

@@ -1033,3 +1057,14 @@ where
10331057
{
10341058
type Mask = isize;
10351059
}
1060+
1061+
// This function matches the semantics of the `bzhi` instruction on x86 BMI2
1062+
// TODO: optimize it further if possible
1063+
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
1064+
fn bzhi_u64(a: u64, ix: u32) -> u64 {
1065+
if ix > 63 {
1066+
a
1067+
} else {
1068+
a & (1u64 << ix) - 1
1069+
}
1070+
}

0 commit comments

Comments
 (0)