Skip to content

Commit d41a948

Browse files
committed
genericize simd uniform int
remove some debug stuff remove bernoulli foo
1 parent e614fd7 commit d41a948

File tree

5 files changed

+39
-82
lines changed

5 files changed

+39
-82
lines changed

benches/misc.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fn misc_bernoulli_const(b: &mut Bencher) {
7575
let d = rand::distributions::Bernoulli::new(0.18).unwrap();
7676
let mut accum = true;
7777
for _ in 0..crate::RAND_BENCH_N {
78-
accum ^= rng.sample::<bool, _>(d);
78+
accum ^= rng.sample(d);
7979
}
8080
accum
8181
})
@@ -89,7 +89,7 @@ fn misc_bernoulli_var(b: &mut Bencher) {
8989
let mut p = 0.18;
9090
for _ in 0..crate::RAND_BENCH_N {
9191
let d = Bernoulli::new(p).unwrap();
92-
accum ^= rng.sample::<bool, _>(d);
92+
accum ^= rng.sample(d);
9393
p += 0.0001;
9494
}
9595
accum

src/distributions/bernoulli.rs

-19
Original file line numberDiff line numberDiff line change
@@ -154,25 +154,6 @@ impl Distribution<bool> for Bernoulli {
154154
}
155155
}
156156

157-
/// Requires nightly Rust and the [`simd_support`] feature
158-
///
159-
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
160-
#[cfg(feature = "simd_support")]
161-
impl<const LANES: usize> Distribution<Mask<i64, LANES>> for Bernoulli
162-
where
163-
LaneCount<LANES>: SupportedLaneCount,
164-
Standard: Distribution<Simd<u64, LANES>>,
165-
{
166-
// TODO: revisit for https://github.com/rust-random/rand/issues/1227
167-
#[inline]
168-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<i64, LANES> {
169-
if self.p_int == ALWAYS_TRUE {
170-
return Mask::splat(true);
171-
}
172-
rng.gen().lanes_lt(Simd::splat(self.p_int))
173-
}
174-
}
175-
176157
#[cfg(test)]
177158
mod test {
178159
use super::Bernoulli;

src/distributions/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ mod float;
9999
mod integer;
100100
mod other;
101101
mod slice;
102-
pub mod utils;
102+
mod utils;
103103
#[cfg(feature = "alloc")]
104104
mod weighted_index;
105105

src/distributions/uniform.rs

+36-59
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ use core::time::Duration;
110110
use core::ops::{Range, RangeInclusive};
111111

112112
use crate::distributions::float::IntoFloat;
113-
use crate::distributions::utils::{BoolAsSIMD, IntAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply};
113+
use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply};
114114
use crate::distributions::Distribution;
115+
#[cfg(feature = "simd_support")]
116+
use crate::distributions::Standard;
115117
use crate::{Rng, RngCore};
116118

117119
#[cfg(not(feature = "std"))]
@@ -571,21 +573,30 @@ uniform_int_impl! { u128, u128, u128 }
571573

572574
#[cfg(feature = "simd_support")]
573575
macro_rules! uniform_simd_int_impl {
574-
($ty:ident, $unsigned:ident, $u_scalar:ident) => {
576+
($ty:ident, $unsigned:ident) => {
575577
// The "pick the largest zone that can fit in an `u32`" optimization
576578
// is less useful here. Multiple lanes complicate things, we don't
577579
// know the PRNG's minimal output size, and casting to a larger vector
578580
// is generally a bad idea for SIMD performance. The user can still
579581
// implement it manually.
580-
581-
// TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
582-
// perhaps `impl SampleUniform for $u_scalar`?
583-
impl SampleUniform for $ty {
584-
type Sampler = UniformInt<$ty>;
582+
impl<const LANES: usize> SampleUniform for Simd<$ty, LANES>
583+
where
584+
LaneCount<LANES>: SupportedLaneCount,
585+
Simd<$unsigned, LANES>:
586+
WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
587+
Standard: Distribution<Simd<$unsigned, LANES>>,
588+
{
589+
type Sampler = UniformInt<Simd<$ty, LANES>>;
585590
}
586591

587-
impl UniformSampler for UniformInt<$ty> {
588-
type X = $ty;
592+
impl<const LANES: usize> UniformSampler for UniformInt<Simd<$ty, LANES>>
593+
where
594+
LaneCount<LANES>: SupportedLaneCount,
595+
Simd<$unsigned, LANES>:
596+
WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
597+
Standard: Distribution<Simd<$unsigned, LANES>>,
598+
{
599+
type X = Simd<$ty, LANES>;
589600

590601
#[inline] // if the range is constant, this helps LLVM to do the
591602
// calculations at compile-time.
@@ -609,13 +620,13 @@ macro_rules! uniform_simd_int_impl {
609620
let high = *high_b.borrow();
610621
assert!(low.lanes_le(high).all(),
611622
"Uniform::new_inclusive called with `low > high`");
612-
let unsigned_max = Simd::splat(::core::$u_scalar::MAX);
623+
let unsigned_max = Simd::splat(::core::$unsigned::MAX);
613624

614-
// NOTE: these may need to be replaced with explicitly
615-
// wrapping operations if `packed_simd` changes
616-
let range: $unsigned = ((high - low) + Simd::splat(1)).cast();
625+
// NOTE: all `Simd` operations are inherently wrapping,
626+
// see https://doc.rust-lang.org/std/simd/struct.Simd.html
627+
let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();
617628
// `% 0` will panic at runtime.
618-
let not_full_range = range.lanes_gt($unsigned::splat(0));
629+
let not_full_range = range.lanes_gt(Simd::splat(0));
619630
// replacing 0 with `unsigned_max` allows a faster `select`
620631
// with bitwise OR
621632
let modulo = not_full_range.select(range, unsigned_max);
@@ -634,8 +645,8 @@ macro_rules! uniform_simd_int_impl {
634645
}
635646

636647
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
637-
let range: $unsigned = self.range.cast();
638-
let zone: $unsigned = self.z.cast();
648+
let range: Simd<$unsigned, LANES> = self.range.cast();
649+
let zone: Simd<$unsigned, LANES> = self.z.cast();
639650

640651
// This might seem very slow, generating a whole new
641652
// SIMD vector for every sample rejection. For most uses
@@ -646,19 +657,19 @@ macro_rules! uniform_simd_int_impl {
646657
// rejection. The replacement method does however add a little
647658
// overhead. Benchmarking or calculating probabilities might
648659
// reveal contexts where this replacement method is slower.
649-
let mut v: $unsigned = rng.gen();
660+
let mut v: Simd<$unsigned, LANES> = rng.gen();
650661
loop {
651662
let (hi, lo) = v.wmul(range);
652663
let mask = lo.lanes_le(zone);
653664
if mask.all() {
654-
let hi: $ty = hi.cast();
665+
let hi: Simd<$ty, LANES> = hi.cast();
655666
// wrapping addition
656667
let result = self.low + hi;
657668
// `select` here compiles to a blend operation
658669
// When `range.eq(0).none()` the compare and blend
659670
// operations are avoided.
660-
let v: $ty = v.cast();
661-
return range.lanes_gt($unsigned::splat(0)).select(result, v);
671+
let v: Simd<$ty, LANES> = v.cast();
672+
return range.lanes_gt(Simd::splat(0)).select(result, v);
662673
}
663674
// Replace only the failing lanes
664675
v = mask.select(v, rng.gen());
@@ -668,50 +679,16 @@ macro_rules! uniform_simd_int_impl {
668679
};
669680

670681
// bulk implementation
671-
($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
682+
($(($unsigned:ident, $signed:ident)),+) => {
672683
$(
673-
uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
674-
uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
684+
uniform_simd_int_impl!($unsigned, $unsigned);
685+
uniform_simd_int_impl!($signed, $unsigned);
675686
)+
676687
};
677688
}
678689

679690
#[cfg(feature = "simd_support")]
680-
uniform_simd_int_impl! {
681-
(u64x2, i64x2),
682-
(u64x4, i64x4),
683-
(u64x8, i64x8),
684-
u64
685-
}
686-
687-
#[cfg(feature = "simd_support")]
688-
uniform_simd_int_impl! {
689-
(u32x2, i32x2),
690-
(u32x4, i32x4),
691-
(u32x8, i32x8),
692-
(u32x16, i32x16),
693-
u32
694-
}
695-
696-
#[cfg(feature = "simd_support")]
697-
uniform_simd_int_impl! {
698-
(u16x2, i16x2),
699-
(u16x4, i16x4),
700-
(u16x8, i16x8),
701-
(u16x16, i16x16),
702-
(u16x32, i16x32),
703-
u16
704-
}
705-
706-
#[cfg(feature = "simd_support")]
707-
uniform_simd_int_impl! {
708-
(u8x4, i8x4),
709-
(u8x8, i8x8),
710-
(u8x16, i8x16),
711-
(u8x32, i8x32),
712-
(u8x64, i8x64),
713-
u8
714-
}
691+
uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) }
715692

716693
impl SampleUniform for char {
717694
type Sampler = UniformChar;
@@ -1183,7 +1160,7 @@ mod tests {
11831160
_ => panic!("`UniformDurationMode` was not serialized/deserialized correctly")
11841161
}
11851162
}
1186-
1163+
11871164
#[test]
11881165
#[cfg(feature = "serde1")]
11891166
fn test_uniform_serialization() {

src/distributions/utils.rs

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
//! Math helper functions
1010
11-
#[cfg(feature = "simd_support")] use core::mem;
1211
#[cfg(feature = "simd_support")] use core::simd::*;
1312

1413

0 commit comments

Comments
 (0)