Skip to content

implement SIMD UniformInt #561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ default = ["std" ] # without "std" rand uses libcore
nightly = ["simd_support"] # enables all features requiring nightly rust
std = ["rand_core/std", "alloc", "libc", "winapi", "cloudabi", "fuchsia-zircon"]
alloc = ["rand_core/alloc"] # enables Vec and Box support (without std)
i128_support = [] # dummy feature for backwards compatibility
i128_support = [] # enables i128 and u128 support
simd_support = ["packed_simd"] # enables SIMD support
serde1 = ["rand_core/serde1", "rand_isaac/serde1", "rand_xorshift/serde1"] # enables serialization for PRNGs

Expand Down
284 changes: 235 additions & 49 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
//! ```
//! use rand::{Rng, thread_rng};
//! use rand::distributions::Uniform;
//!
//!
//! let mut rng = thread_rng();
//! let side = Uniform::new(-10.0, 10.0);
//!
//!
//! // sample between 1 and 10 points
//! for _ in 0..rng.gen_range(1, 11) {
//! // sample a point from the square with sides -10 - 10 in two dimensions
Expand Down Expand Up @@ -482,6 +482,150 @@ uniform_int_impl! { usize, isize, usize, isize, usize }
#[cfg(rust_1_26)]
uniform_int_impl! { u128, u128, u128, i128, u128 }

#[cfg(feature = "simd_support")]
macro_rules! uniform_simd_int_impl {
($ty:ident, $unsigned:ident, $u_scalar:ident) => {
// The "pick the largest zone that can fit in an `u32`" optimization
// is less useful here. Multiple lanes complicate things, we don't
// know the PRNG's minimal output size, and casting to a larger vector
// is generally a bad idea for SIMD performance. The user can still
// implement it manually.

// TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
// perhaps `impl SampleUniform for $u_scalar`?
impl SampleUniform for $ty {
type Sampler = UniformInt<$ty>;
}

impl UniformSampler for UniformInt<$ty> {
type X = $ty;

#[inline] // if the range is constant, this helps LLVM to do the
// calculations at compile-time.
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low.lt(high).all(), "Uniform::new called with `low >= high`");
UniformSampler::new_inclusive(low, high - 1)
}

#[inline] // if the range is constant, this helps LLVM to do the
// calculations at compile-time.
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low.le(high).all(),
"Uniform::new_inclusive called with `low > high`");
let unsigned_max = ::core::$u_scalar::MAX;

// NOTE: these may need to be replaced with explicitly
// wrapping operations if `packed_simd` changes
let range: $unsigned = ((high - low) + 1).cast();
// `% 0` will panic at runtime.
let not_full_range = range.gt($unsigned::splat(0));
// replacing 0 with `unsigned_max` allows a faster `select`
// with bitwise OR
let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max));
// wrapping addition
let ints_to_reject = (unsigned_max - range + 1) % modulo;
// When `range` is 0, `lo` of `v.wmul(range)` will always be
// zero which means only one sample is needed.
let zone = unsigned_max - ints_to_reject;

UniformInt {
low: low,
// These are really $unsigned values, but store as $ty:
range: range.cast(),
zone: zone.cast(),
}
}

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
let range: $unsigned = self.range.cast();
let zone: $unsigned = self.zone.cast();

// This might seem very slow, generating a whole new
// SIMD vector for every sample rejection. For most uses
// though, the chance of rejection is small and provides good
// general performance. With multiple lanes, that chance is
// multiplied. To mitigate this, we replace only the lanes of
// the vector which fail, iteratively reducing the chance of
// rejection. The replacement method does however add a little
// overhead. Benchmarking or calculating probabilities might
// reveal contexts where this replacement method is slower.
let mut v: $unsigned = rng.gen();
loop {
let (hi, lo) = v.wmul(range);
let mask = lo.le(zone);
if mask.all() {
let hi: $ty = hi.cast();
// wrapping addition
let result = self.low + hi;
// `select` here compiles to a blend operation
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: $ty = v.cast();
return range.gt($unsigned::splat(0)).select(result, v);
}
// Replace only the failing lanes
v = mask.select(v, rng.gen());
}
}
}
};

// bulk implementation
($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
$(
uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
)+
};
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u64x2, i64x2),
(u64x4, i64x4),
(u64x8, i64x8),
u64
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u32x2, i32x2),
(u32x4, i32x4),
(u32x8, i32x8),
(u32x16, i32x16),
u32
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u16x2, i16x2),
(u16x4, i16x4),
(u16x8, i16x8),
(u16x16, i16x16),
(u16x32, i16x32),
u16
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u8x2, i8x2),
(u8x4, i8x4),
(u8x8, i8x8),
(u8x16, i8x16),
(u8x32, i8x32),
(u8x64, i8x64),
u8
}


/// The back-end implementing [`UniformSampler`] for floating-point types.
Expand Down Expand Up @@ -817,50 +961,86 @@ mod tests {

#[test]
fn test_integers() {
use core::{i8, i16, i32, i64, isize};
use core::{u8, u16, u32, u64, usize};
#[cfg(rust_1_26)]
use core::{i128, u128};

let mut rng = ::test::rng(251);
macro_rules! t {
($($ty:ident),*) => {{
$(
let v: &[($ty, $ty)] = &[(0, 10),
(10, 127),
(::core::$ty::MIN, ::core::$ty::MAX)];
for &(low, high) in v.iter() {
let my_uniform = Uniform::new(low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v < high);
}
($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
for &(low, high) in $v.iter() {
let my_uniform = Uniform::new(low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!($le(low, v) && $lt(v, high));
}

let my_uniform = Uniform::new_inclusive(low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v <= high);
}
let my_uniform = Uniform::new_inclusive(low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!($le(low, v) && $le(v, high));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could introduce something like FloatSIMDUtils and FloatAsSIMD::splat, but for integer types. That would allow a lot of this test code to be written more nicely.

But I'll defer to @dhardy and @pitdicker.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could, but what's written here is sufficient.

}

let my_uniform = Uniform::new(&low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v < high);
}
let my_uniform = Uniform::new(&low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!($le(low, v) && $lt(v, high));
}

let my_uniform = Uniform::new_inclusive(&low, &high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v <= high);
}
let my_uniform = Uniform::new_inclusive(&low, &high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!($le(low, v) && $le(v, high));
}

for _ in 0..1000 {
let v: $ty = rng.gen_range(low, high);
assert!(low <= v && v < high);
}
for _ in 0..1000 {
let v: $ty = rng.gen_range(low, high);
assert!($le(low, v) && $lt(v, high));
}
)*
}}
}
}};

// scalar bulk
($($ty:ident),*) => {{
$(t!(
$ty,
[(0, 10), (10, 127), ($ty::MIN, $ty::MAX)],
|x, y| x <= y,
|x, y| x < y
);)*
}};

// simd bulk
($($ty:ident),* => $scalar:ident) => {{
$(t!(
$ty,
[
($ty::splat(0), $ty::splat(10)),
($ty::splat(10), $ty::splat(127)),
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
],
|x: $ty, y| x.le(y).all(),
|x: $ty, y| x.lt(y).all()
);)*
}};
}
t!(i8, i16, i32, i64, isize,
u8, u16, u32, u64, usize);
#[cfg(rust_1_26)]
t!(i128, u128)
t!(i128, u128);

#[cfg(feature = "simd_support")]
{
t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8);
t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8);
t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16);
t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16);
t!(u32x2, u32x4, u32x8, u32x16 => u32);
t!(i32x2, i32x4, i32x8, i32x16 => i32);
t!(u64x2, u64x4, u64x8 => u64);
t!(i64x2, i64x4, i64x8 => i64);
}
}

#[test]
Expand Down Expand Up @@ -932,13 +1112,16 @@ mod tests {

t!(f32, f32, 32 - 23);
t!(f64, f64, 64 - 52);
#[cfg(feature="simd_support")] t!(f32x2, f32, 32 - 23);
#[cfg(feature="simd_support")] t!(f32x4, f32, 32 - 23);
#[cfg(feature="simd_support")] t!(f32x8, f32, 32 - 23);
#[cfg(feature="simd_support")] t!(f32x16, f32, 32 - 23);
#[cfg(feature="simd_support")] t!(f64x2, f64, 64 - 52);
#[cfg(feature="simd_support")] t!(f64x4, f64, 64 - 52);
#[cfg(feature="simd_support")] t!(f64x8, f64, 64 - 52);
#[cfg(feature="simd_support")]
{
t!(f32x2, f32, 32 - 23);
t!(f32x4, f32, 32 - 23);
t!(f32x8, f32, 32 - 23);
t!(f32x16, f32, 32 - 23);
t!(f64x2, f64, 64 - 52);
t!(f64x4, f64, 64 - 52);
t!(f64x8, f64, 64 - 52);
}
}

#[test]
Expand Down Expand Up @@ -985,13 +1168,16 @@ mod tests {

t!(f32, f32);
t!(f64, f64);
#[cfg(feature="simd_support")] t!(f32x2, f32);
#[cfg(feature="simd_support")] t!(f32x4, f32);
#[cfg(feature="simd_support")] t!(f32x8, f32);
#[cfg(feature="simd_support")] t!(f32x16, f32);
#[cfg(feature="simd_support")] t!(f64x2, f64);
#[cfg(feature="simd_support")] t!(f64x4, f64);
#[cfg(feature="simd_support")] t!(f64x8, f64);
#[cfg(feature="simd_support")]
{
t!(f32x2, f32);
t!(f32x4, f32);
t!(f32x8, f32);
t!(f32x16, f32);
t!(f64x2, f64);
t!(f64x4, f64);
t!(f64x8, f64);
}
}


Expand Down
Loading