Skip to content

Commit 9d0c88a

Browse files
committed
dedup SIMD sample_single_inclusive_bitmask
1 parent a97199b commit 9d0c88a

File tree

1 file changed

+45
-90
lines changed

1 file changed

+45
-90
lines changed

src/distributions/uniform/uniform_int.rs

Lines changed: 45 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,51 @@ macro_rules! uniform_simd_int_impl {
715715
let range: $unsigned = ((high - low) + 1).cast();
716716
(range, low)
717717
}
718+
719+
/// Samples one value per lane from the inclusive range `[low_b, high_b]`
/// using bitmask rejection.
///
/// Random bits are masked down to an all-ones bit pattern covering the
/// largest acceptable offset (`high - low`). Lanes whose masked value
/// exceeds that offset are redrawn until every lane is accepted, and the
/// accepted offsets are added to `low`. Lanes flagged by `is_full_range`
/// (the range computed by `sample_inc_setup` equals 0) return the raw
/// random bits instead.
720+
#[inline(always)]
721+
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
722+
low_b: B1, high_b: B2, rng: &mut R,
723+
) -> $ty
724+
where
725+
B1: SampleBorrow<$ty> + Sized,
726+
B2: SampleBorrow<$ty> + Sized,
727+
{
728+
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
729+
let is_full_range = range.eq($unsigned::splat(0));
730+
731+
// generate bitmask: after the decrement `range` is the largest acceptable offset
732+
range -= 1; // NOTE(review): presumably wraps to all-ones in full-range lanes -- confirm
733+
let mut mask = range | 1;
734+
735+
mask |= mask >> 1; // smear the highest set bit into every lower bit position
736+
mask |= mask >> 2;
737+
mask |= mask >> 4;
738+
739+
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes(); // bits per lane
740+
if LANE_WIDTH >= 16 { mask |= mask >> 8; } // wider lanes need the larger shifts too
741+
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
742+
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
743+
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
744+
745+
let mut v: $unsigned = rng.gen();
746+
loop {
747+
let masked = v & mask;
748+
let accept = masked.le(range); // per-lane test: masked offset is within the range
749+
if accept.all() {
750+
let masked: $ty = masked.cast();
751+
// wrapping addition
752+
let result = low + masked;
753+
// `select` here compiles to a blend operation
754+
// When `range.eq(0).none()` the compare and blend
755+
// operations are avoided.
756+
let v: $ty = v.cast();
757+
return is_full_range.select(v, result);
758+
}
759+
// Replace only the failing lanes; accepted lanes keep their current bits
760+
v = accept.select(v, rng.gen());
761+
}
762+
}
718763
}
719764
};
720765

@@ -806,51 +851,6 @@ macro_rules! uniform_simd_int_gt8_impl {
806851
let cast_rand_bits: $ty = rand_bits.cast();
807852
is_full_range.select(cast_rand_bits, low + cast_result)
808853
}
809-
810-
/// Samples one value per lane from the inclusive range `[low_b, high_b]`
/// using bitmask rejection.
///
/// Random bits are masked down to an all-ones bit pattern covering the
/// largest acceptable offset (`high - low`). Lanes whose masked value
/// exceeds that offset are redrawn until every lane is accepted, and the
/// accepted offsets are added to `low`. Lanes flagged by `is_full_range`
/// (the range computed by `sample_inc_setup` equals 0) return the raw
/// random bits instead.
811-
#[inline(always)]
812-
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
813-
low_b: B1, high_b: B2, rng: &mut R,
814-
) -> $ty
815-
where
816-
B1: SampleBorrow<$ty> + Sized,
817-
B2: SampleBorrow<$ty> + Sized,
818-
{
819-
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
820-
let is_full_range = range.eq($unsigned::splat(0));
821-
822-
// generate bitmask: after the decrement `range` is the largest acceptable offset
823-
range -= 1; // NOTE(review): presumably wraps to all-ones in full-range lanes -- confirm
824-
let mut mask = range | 1;
825-
826-
mask |= mask >> 1; // smear the highest set bit into every lower bit position
827-
mask |= mask >> 2;
828-
mask |= mask >> 4;
829-
830-
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes(); // bits per lane
831-
if LANE_WIDTH >= 16 { mask |= mask >> 8; } // wider lanes need the larger shifts too
832-
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
833-
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
834-
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
835-
836-
let mut v: $unsigned = rng.gen();
837-
loop {
838-
let masked = v & mask;
839-
let accept = masked.le(range); // per-lane test: masked offset is within the range
840-
if accept.all() {
841-
let masked: $ty = masked.cast();
842-
// wrapping addition
843-
let result = low + masked;
844-
// `select` here compiles to a blend operation
845-
// When `range.eq(0).none()` the compare and blend
846-
// operations are avoided.
847-
let v: $ty = v.cast();
848-
return is_full_range.select(v, result);
849-
}
850-
// Replace only the failing lanes; accepted lanes keep their current bits
851-
v = accept.select(v, rng.gen());
852-
}
853-
}
854854
}
855855
};
856856

@@ -966,51 +966,6 @@ macro_rules! uniform_simd_int_le8_impl {
966966
let cast_rand_bits: $ty = rand_bits.cast();
967967
is_full_range.select(cast_rand_bits, low + cast_result)
968968
}
969-
970-
/// Samples one value per lane from the inclusive range `[low_b, high_b]`
/// using bitmask rejection.
///
/// Random bits are masked down to an all-ones bit pattern covering the
/// largest acceptable offset (`high - low`). Lanes whose masked value
/// exceeds that offset are redrawn until every lane is accepted, and the
/// accepted offsets are added to `low`. Lanes flagged by `is_full_range`
/// (the range computed by `sample_inc_setup` equals 0) return the raw
/// random bits instead.
971-
#[inline(always)]
972-
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
973-
low_b: B1, high_b: B2, rng: &mut R,
974-
) -> $ty
975-
where
976-
B1: SampleBorrow<$ty> + Sized,
977-
B2: SampleBorrow<$ty> + Sized,
978-
{
979-
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
980-
let is_full_range = range.eq($unsigned::splat(0));
981-
982-
// generate bitmask: after the decrement `range` is the largest acceptable offset
983-
range -= 1; // NOTE(review): presumably wraps to all-ones in full-range lanes -- confirm
984-
let mut mask = range | 1;
985-
986-
mask |= mask >> 1; // smear the highest set bit into every lower bit position
987-
mask |= mask >> 2;
988-
mask |= mask >> 4;
989-
990-
const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes(); // bits per lane
991-
if LANE_WIDTH >= 16 { mask |= mask >> 8; } // wider lanes need the larger shifts too
992-
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
993-
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
994-
if LANE_WIDTH >= 128 { mask |= mask >> 64; }
995-
996-
let mut v: $unsigned = rng.gen();
997-
loop {
998-
let masked = v & mask;
999-
let accept = masked.le(range); // per-lane test: masked offset is within the range
1000-
if accept.all() {
1001-
let masked: $ty = masked.cast();
1002-
// wrapping addition
1003-
let result = low + masked;
1004-
// `select` here compiles to a blend operation
1005-
// When `range.eq(0).none()` the compare and blend
1006-
// operations are avoided.
1007-
let v: $ty = v.cast();
1008-
return is_full_range.select(v, result);
1009-
}
1010-
// Replace only the failing lanes; accepted lanes keep their current bits
1011-
v = accept.select(v, rng.gen());
1012-
}
1013-
}
1014969
}
1015970
};
1016971

0 commit comments

Comments
 (0)