Skip to content

Commit fa269c3

Browse files
authored
Merge pull request #561 from TheIronBorn/patch-6
implement SIMD UniformInt
2 parents af8aa52 + b45e54f commit fa269c3

File tree

3 files changed

+360
-52
lines changed

3 files changed

+360
-52
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ default = ["std" ] # without "std" rand uses libcore
2323
nightly = ["simd_support"] # enables all features requiring nightly rust
2424
std = ["rand_core/std", "alloc", "libc", "winapi", "cloudabi", "fuchsia-zircon"]
2525
alloc = ["rand_core/alloc"] # enables Vec and Box support (without std)
26-
i128_support = [] # dummy feature for backwards compatibility
26+
i128_support = [] # dummy feature for backwards compatibility; 128-bit impls are gated on cfg(rust_1_26)
2727
simd_support = ["packed_simd"] # enables SIMD support
2828
serde1 = ["rand_core/serde1", "rand_isaac/serde1", "rand_xorshift/serde1"] # enables serialization for PRNGs
2929

src/distributions/uniform.rs

+235-49
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
//! ```
3131
//! use rand::{Rng, thread_rng};
3232
//! use rand::distributions::Uniform;
33-
//!
33+
//!
3434
//! let mut rng = thread_rng();
3535
//! let side = Uniform::new(-10.0, 10.0);
36-
//!
36+
//!
3737
//! // sample between 1 and 10 points
3838
//! for _ in 0..rng.gen_range(1, 11) {
3939
//! // sample a point from the square with sides spanning -10 to 10 in two dimensions
@@ -482,6 +482,150 @@ uniform_int_impl! { usize, isize, usize, isize, usize }
482482
#[cfg(rust_1_26)]
483483
uniform_int_impl! { u128, u128, u128, i128, u128 }
484484

485+
#[cfg(feature = "simd_support")]
486+
macro_rules! uniform_simd_int_impl {
487+
($ty:ident, $unsigned:ident, $u_scalar:ident) => {
488+
// The "pick the largest zone that can fit in a `u32`" optimization
489+
// is less useful here. Multiple lanes complicate things, we don't
490+
// know the PRNG's minimal output size, and casting to a larger vector
491+
// is generally a bad idea for SIMD performance. The user can still
492+
// implement it manually.
493+
494+
// TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
495+
// perhaps `impl SampleUniform for $u_scalar`?
496+
impl SampleUniform for $ty {
497+
type Sampler = UniformInt<$ty>;
498+
}
499+
500+
impl UniformSampler for UniformInt<$ty> {
501+
type X = $ty;
502+
503+
#[inline] // if the range is constant, this helps LLVM to do the
504+
// calculations at compile-time.
505+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
506+
where B1: SampleBorrow<Self::X> + Sized,
507+
B2: SampleBorrow<Self::X> + Sized
508+
{
509+
let low = *low_b.borrow();
510+
let high = *high_b.borrow();
511+
assert!(low.lt(high).all(), "Uniform::new called with `low >= high`");
512+
UniformSampler::new_inclusive(low, high - 1)
513+
}
514+
515+
#[inline] // if the range is constant, this helps LLVM to do the
516+
// calculations at compile-time.
517+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
518+
where B1: SampleBorrow<Self::X> + Sized,
519+
B2: SampleBorrow<Self::X> + Sized
520+
{
521+
let low = *low_b.borrow();
522+
let high = *high_b.borrow();
523+
assert!(low.le(high).all(),
524+
"Uniform::new_inclusive called with `low > high`");
525+
let unsigned_max = ::core::$u_scalar::MAX;
526+
527+
// NOTE: these may need to be replaced with explicitly
528+
// wrapping operations if `packed_simd` changes
529+
let range: $unsigned = ((high - low) + 1).cast();
530+
// `% 0` will panic at runtime.
531+
let not_full_range = range.gt($unsigned::splat(0));
532+
// replacing 0 with `unsigned_max` allows a faster `select`
533+
// with bitwise OR
534+
let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max));
535+
// wrapping addition: when `range` is 0 (full range), `unsigned_max - range + 1` wraps to 0
536+
let ints_to_reject = (unsigned_max - range + 1) % modulo;
537+
// When `range` is 0, `lo` of `v.wmul(range)` will always be
538+
zero, which means only one sample is needed.
539+
let zone = unsigned_max - ints_to_reject;
540+
541+
UniformInt {
542+
low: low,
543+
// These are really $unsigned values, but are stored as $ty:
544+
range: range.cast(),
545+
zone: zone.cast(),
546+
}
547+
}
548+
549+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
550+
let range: $unsigned = self.range.cast();
551+
let zone: $unsigned = self.zone.cast();
552+
553+
// This might seem very slow, generating a whole new
554+
// SIMD vector for every sample rejection. For most uses
555+
// though, the chance of rejection is small, which gives good
556+
// general performance. With multiple lanes, that chance is
557+
// multiplied. To mitigate this, we replace only the lanes of
558+
// the vector which fail, iteratively reducing the chance of
559+
// rejection. The replacement method does however add a little
560+
// overhead. Benchmarking or calculating probabilities might
561+
// reveal contexts where this replacement method is slower.
562+
let mut v: $unsigned = rng.gen();
563+
loop {
564+
let (hi, lo) = v.wmul(range);
565+
let mask = lo.le(zone);
566+
if mask.all() {
567+
let hi: $ty = hi.cast();
568+
// wrapping addition
569+
let result = self.low + hi;
570+
// `select` here compiles to a blend operation.
571+
// When `range.eq(0).none()` the compare and blend
572+
// operations are avoided.
573+
let v: $ty = v.cast();
574+
return range.gt($unsigned::splat(0)).select(result, v);
575+
}
576+
// Replace only the failing lanes
577+
v = mask.select(v, rng.gen());
578+
}
579+
}
580+
}
581+
};
582+
583+
// bulk implementation
584+
($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
585+
$(
586+
uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
587+
uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
588+
)+
589+
};
590+
}
591+
592+
#[cfg(feature = "simd_support")]
593+
uniform_simd_int_impl! {
594+
(u64x2, i64x2),
595+
(u64x4, i64x4),
596+
(u64x8, i64x8),
597+
u64
598+
}
599+
600+
#[cfg(feature = "simd_support")]
601+
uniform_simd_int_impl! {
602+
(u32x2, i32x2),
603+
(u32x4, i32x4),
604+
(u32x8, i32x8),
605+
(u32x16, i32x16),
606+
u32
607+
}
608+
609+
#[cfg(feature = "simd_support")]
610+
uniform_simd_int_impl! {
611+
(u16x2, i16x2),
612+
(u16x4, i16x4),
613+
(u16x8, i16x8),
614+
(u16x16, i16x16),
615+
(u16x32, i16x32),
616+
u16
617+
}
618+
619+
#[cfg(feature = "simd_support")]
620+
uniform_simd_int_impl! {
621+
(u8x2, i8x2),
622+
(u8x4, i8x4),
623+
(u8x8, i8x8),
624+
(u8x16, i8x16),
625+
(u8x32, i8x32),
626+
(u8x64, i8x64),
627+
u8
628+
}
485629

486630

487631
/// The back-end implementing [`UniformSampler`] for floating-point types.
@@ -817,50 +961,86 @@ mod tests {
817961

818962
#[test]
819963
fn test_integers() {
964+
use core::{i8, i16, i32, i64, isize};
965+
use core::{u8, u16, u32, u64, usize};
966+
#[cfg(rust_1_26)]
967+
use core::{i128, u128};
968+
820969
let mut rng = ::test::rng(251);
821970
macro_rules! t {
822-
($($ty:ident),*) => {{
823-
$(
824-
let v: &[($ty, $ty)] = &[(0, 10),
825-
(10, 127),
826-
(::core::$ty::MIN, ::core::$ty::MAX)];
827-
for &(low, high) in v.iter() {
828-
let my_uniform = Uniform::new(low, high);
829-
for _ in 0..1000 {
830-
let v: $ty = rng.sample(my_uniform);
831-
assert!(low <= v && v < high);
832-
}
971+
($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
972+
for &(low, high) in $v.iter() {
973+
let my_uniform = Uniform::new(low, high);
974+
for _ in 0..1000 {
975+
let v: $ty = rng.sample(my_uniform);
976+
assert!($le(low, v) && $lt(v, high));
977+
}
833978

834-
let my_uniform = Uniform::new_inclusive(low, high);
835-
for _ in 0..1000 {
836-
let v: $ty = rng.sample(my_uniform);
837-
assert!(low <= v && v <= high);
838-
}
979+
let my_uniform = Uniform::new_inclusive(low, high);
980+
for _ in 0..1000 {
981+
let v: $ty = rng.sample(my_uniform);
982+
assert!($le(low, v) && $le(v, high));
983+
}
839984

840-
let my_uniform = Uniform::new(&low, high);
841-
for _ in 0..1000 {
842-
let v: $ty = rng.sample(my_uniform);
843-
assert!(low <= v && v < high);
844-
}
985+
let my_uniform = Uniform::new(&low, high);
986+
for _ in 0..1000 {
987+
let v: $ty = rng.sample(my_uniform);
988+
assert!($le(low, v) && $lt(v, high));
989+
}
845990

846-
let my_uniform = Uniform::new_inclusive(&low, &high);
847-
for _ in 0..1000 {
848-
let v: $ty = rng.sample(my_uniform);
849-
assert!(low <= v && v <= high);
850-
}
991+
let my_uniform = Uniform::new_inclusive(&low, &high);
992+
for _ in 0..1000 {
993+
let v: $ty = rng.sample(my_uniform);
994+
assert!($le(low, v) && $le(v, high));
995+
}
851996

852-
for _ in 0..1000 {
853-
let v: $ty = rng.gen_range(low, high);
854-
assert!(low <= v && v < high);
855-
}
997+
for _ in 0..1000 {
998+
let v: $ty = rng.gen_range(low, high);
999+
assert!($le(low, v) && $lt(v, high));
8561000
}
857-
)*
858-
}}
1001+
}
1002+
}};
1003+
1004+
// scalar bulk
1005+
($($ty:ident),*) => {{
1006+
$(t!(
1007+
$ty,
1008+
[(0, 10), (10, 127), ($ty::MIN, $ty::MAX)],
1009+
|x, y| x <= y,
1010+
|x, y| x < y
1011+
);)*
1012+
}};
1013+
1014+
// simd bulk
1015+
($($ty:ident),* => $scalar:ident) => {{
1016+
$(t!(
1017+
$ty,
1018+
[
1019+
($ty::splat(0), $ty::splat(10)),
1020+
($ty::splat(10), $ty::splat(127)),
1021+
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
1022+
],
1023+
|x: $ty, y| x.le(y).all(),
1024+
|x: $ty, y| x.lt(y).all()
1025+
);)*
1026+
}};
8591027
}
8601028
t!(i8, i16, i32, i64, isize,
8611029
u8, u16, u32, u64, usize);
8621030
#[cfg(rust_1_26)]
863-
t!(i128, u128)
1031+
t!(i128, u128);
1032+
1033+
#[cfg(feature = "simd_support")]
1034+
{
1035+
t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8);
1036+
t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8);
1037+
t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16);
1038+
t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16);
1039+
t!(u32x2, u32x4, u32x8, u32x16 => u32);
1040+
t!(i32x2, i32x4, i32x8, i32x16 => i32);
1041+
t!(u64x2, u64x4, u64x8 => u64);
1042+
t!(i64x2, i64x4, i64x8 => i64);
1043+
}
8641044
}
8651045

8661046
#[test]
@@ -932,13 +1112,16 @@ mod tests {
9321112

9331113
t!(f32, f32, 32 - 23);
9341114
t!(f64, f64, 64 - 52);
935-
#[cfg(feature="simd_support")] t!(f32x2, f32, 32 - 23);
936-
#[cfg(feature="simd_support")] t!(f32x4, f32, 32 - 23);
937-
#[cfg(feature="simd_support")] t!(f32x8, f32, 32 - 23);
938-
#[cfg(feature="simd_support")] t!(f32x16, f32, 32 - 23);
939-
#[cfg(feature="simd_support")] t!(f64x2, f64, 64 - 52);
940-
#[cfg(feature="simd_support")] t!(f64x4, f64, 64 - 52);
941-
#[cfg(feature="simd_support")] t!(f64x8, f64, 64 - 52);
1115+
#[cfg(feature="simd_support")]
1116+
{
1117+
t!(f32x2, f32, 32 - 23);
1118+
t!(f32x4, f32, 32 - 23);
1119+
t!(f32x8, f32, 32 - 23);
1120+
t!(f32x16, f32, 32 - 23);
1121+
t!(f64x2, f64, 64 - 52);
1122+
t!(f64x4, f64, 64 - 52);
1123+
t!(f64x8, f64, 64 - 52);
1124+
}
9421125
}
9431126

9441127
#[test]
@@ -985,13 +1168,16 @@ mod tests {
9851168

9861169
t!(f32, f32);
9871170
t!(f64, f64);
988-
#[cfg(feature="simd_support")] t!(f32x2, f32);
989-
#[cfg(feature="simd_support")] t!(f32x4, f32);
990-
#[cfg(feature="simd_support")] t!(f32x8, f32);
991-
#[cfg(feature="simd_support")] t!(f32x16, f32);
992-
#[cfg(feature="simd_support")] t!(f64x2, f64);
993-
#[cfg(feature="simd_support")] t!(f64x4, f64);
994-
#[cfg(feature="simd_support")] t!(f64x8, f64);
1171+
#[cfg(feature="simd_support")]
1172+
{
1173+
t!(f32x2, f32);
1174+
t!(f32x4, f32);
1175+
t!(f32x8, f32);
1176+
t!(f32x16, f32);
1177+
t!(f64x2, f64);
1178+
t!(f64x4, f64);
1179+
t!(f64x8, f64);
1180+
}
9951181
}
9961182

9971183

0 commit comments

Comments
 (0)