Skip to content

Commit ae23929

Browse files
committed
dedup SIMD sample_single_inclusive_bitmask
1 parent 03b1b99 commit ae23929

File tree

1 file changed

+45
-90
lines changed

1 file changed

+45
-90
lines changed

src/distributions/uniform/uniform_int.rs

Lines changed: 45 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,51 @@ macro_rules! uniform_simd_int_impl {
714714
let range: $unsigned = ((high - low) + 1).cast();
715715
(range, low)
716716
}
717+
718+
/// Sample one value per lane from the inclusive range `[low, high]` using bitmask rejection sampling.
719+
#[inline(always)]
720+
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
721+
low_b: B1, high_b: B2, rng: &mut R,
722+
) -> $ty
723+
where
724+
B1: SampleBorrow<$ty> + Sized,
725+
B2: SampleBorrow<$ty> + Sized,
726+
{
727+
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
728+
let is_full_range = range.eq($unsigned::splat(0));
729+
730+
// generate bitmask
731+
range -= 1;
732+
let mut mask = range | 1;
733+
734+
mask |= mask >> 1;
735+
mask |= mask >> 2;
736+
mask |= mask >> 4;
737+
738+
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
739+
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
740+
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
741+
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
742+
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
743+
744+
let mut v: $unsigned = rng.gen();
745+
loop {
746+
let masked = v & mask;
747+
let accept = masked.le(range);
748+
if accept.all() {
749+
let masked: $ty = masked.cast();
750+
// wrapping addition
751+
let result = low + masked;
752+
// `select` here compiles to a blend operation
753+
// When `range.eq(0).none()` the compare and blend
754+
// operations are avoided.
755+
let v: $ty = v.cast();
756+
return is_full_range.select(v, result);
757+
}
758+
// Replace only the failing lanes
759+
v = accept.select(v, rng.gen());
760+
}
761+
}
717762
}
718763
};
719764

@@ -805,51 +850,6 @@ macro_rules! uniform_simd_int_gt8_impl {
805850
let cast_rand_bits: $ty = rand_bits.cast();
806851
is_full_range.select(cast_rand_bits, low + cast_result)
807852
}
808-
809-
/// Bitmask
810-
#[inline(always)]
811-
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
812-
low_b: B1, high_b: B2, rng: &mut R,
813-
) -> $ty
814-
where
815-
B1: SampleBorrow<$ty> + Sized,
816-
B2: SampleBorrow<$ty> + Sized,
817-
{
818-
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
819-
let is_full_range = range.eq($unsigned::splat(0));
820-
821-
// generate bitmask
822-
range -= 1;
823-
let mut mask = range | 1;
824-
825-
mask |= mask >> 1;
826-
mask |= mask >> 2;
827-
mask |= mask >> 4;
828-
829-
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
830-
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
831-
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
832-
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
833-
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
834-
835-
let mut v: $unsigned = rng.gen();
836-
loop {
837-
let masked = v & mask;
838-
let accept = masked.le(range);
839-
if accept.all() {
840-
let masked: $ty = masked.cast();
841-
// wrapping addition
842-
let result = low + masked;
843-
// `select` here compiles to a blend operation
844-
// When `range.eq(0).none()` the compare and blend
845-
// operations are avoided.
846-
let v: $ty = v.cast();
847-
return is_full_range.select(v, result);
848-
}
849-
// Replace only the failing lanes
850-
v = accept.select(v, rng.gen());
851-
}
852-
}
853853
}
854854
};
855855

@@ -965,51 +965,6 @@ macro_rules! uniform_simd_int_le8_impl {
965965
let cast_rand_bits: $ty = rand_bits.cast();
966966
is_full_range.select(cast_rand_bits, low + cast_result)
967967
}
968-
969-
///
970-
#[inline(always)]
971-
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
972-
low_b: B1, high_b: B2, rng: &mut R,
973-
) -> $ty
974-
where
975-
B1: SampleBorrow<$ty> + Sized,
976-
B2: SampleBorrow<$ty> + Sized,
977-
{
978-
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
979-
let is_full_range = range.eq($unsigned::splat(0));
980-
981-
// generate bitmask
982-
range -= 1;
983-
let mut mask = range | 1;
984-
985-
mask |= mask >> 1;
986-
mask |= mask >> 2;
987-
mask |= mask >> 4;
988-
989-
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
990-
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
991-
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
992-
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
993-
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
994-
995-
let mut v: $unsigned = rng.gen();
996-
loop {
997-
let masked = v & mask;
998-
let accept = masked.le(range);
999-
if accept.all() {
1000-
let masked: $ty = masked.cast();
1001-
// wrapping addition
1002-
let result = low + masked;
1003-
// `select` here compiles to a blend operation
1004-
// When `range.eq(0).none()` the compare and blend
1005-
// operations are avoided.
1006-
let v: $ty = v.cast();
1007-
return is_full_range.select(v, result);
1008-
}
1009-
// Replace only the failing lanes
1010-
v = accept.select(v, rng.gen());
1011-
}
1012-
}
1013968
}
1014969
};
1015970

0 commit comments

Comments
 (0)