Skip to content

Commit e614fd7

Browse files
committed
switch to std::simd, expand SIMD stuff & docs
move __m128i to stable, expand documentation, add SIMD to Bernoulli, add maskNxM, add __m512i
1 parent 3543f4b commit e614fd7

File tree

10 files changed

+287
-180
lines changed

10 files changed

+287
-180
lines changed

Cargo.toml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ alloc = ["rand_core/alloc"]
4141
# Option: use getrandom package for seeding
4242
getrandom = ["rand_core/getrandom"]
4343

44-
# Option (requires nightly): experimental SIMD support
45-
simd_support = ["packed_simd"]
44+
# Option (requires nightly Rust): experimental SIMD support
45+
simd_support = []
4646

4747
# Option (enabled by default): enable StdRng
4848
std_rng = ["rand_chacha"]
@@ -68,13 +68,6 @@ log = { version = "0.4.4", optional = true }
6868
serde = { version = "1.0.103", features = ["derive"], optional = true }
6969
rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true }
7070

71-
[dependencies.packed_simd]
72-
# NOTE: so far no version works reliably due to dependence on unstable features
73-
package = "packed_simd_2"
74-
version = "0.3.7"
75-
optional = true
76-
features = ["into_bits"]
77-
7871
[target.'cfg(unix)'.dependencies]
7972
# Used for fork protection (reseeding.rs)
8073
libc = { version = "0.2.22", optional = true, default-features = false }

benches/misc.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fn misc_bernoulli_const(b: &mut Bencher) {
7575
let d = rand::distributions::Bernoulli::new(0.18).unwrap();
7676
let mut accum = true;
7777
for _ in 0..crate::RAND_BENCH_N {
78-
accum ^= rng.sample(d);
78+
accum ^= rng.sample::<bool, _>(d);
7979
}
8080
accum
8181
})
@@ -89,7 +89,7 @@ fn misc_bernoulli_var(b: &mut Bencher) {
8989
let mut p = 0.18;
9090
for _ in 0..crate::RAND_BENCH_N {
9191
let d = Bernoulli::new(p).unwrap();
92-
accum ^= rng.sample(d);
92+
accum ^= rng.sample::<bool, _>(d);
9393
p += 0.0001;
9494
}
9595
accum

src/distributions/bernoulli.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
1111
use crate::distributions::Distribution;
1212
use crate::Rng;
13+
#[cfg(feature = "simd_support")]
14+
use crate::distributions::Standard;
15+
#[cfg(feature = "simd_support")]
16+
use core::simd::{LaneCount, Mask, Simd, SupportedLaneCount};
1317
use core::{fmt, u64};
1418

1519
#[cfg(feature = "serde1")]
1620
use serde::{Serialize, Deserialize};
21+
1722
/// The Bernoulli distribution.
1823
///
1924
/// This is a special case of the Binomial distribution where `n = 1`.
@@ -24,7 +29,7 @@ use serde::{Serialize, Deserialize};
2429
/// use rand::distributions::{Bernoulli, Distribution};
2530
///
2631
/// let d = Bernoulli::new(0.3).unwrap();
27-
/// let v = d.sample(&mut rand::thread_rng());
32+
/// let v: bool = d.sample(&mut rand::thread_rng());
2833
/// println!("{} is from a Bernoulli distribution", v);
2934
/// ```
3035
///
@@ -33,6 +38,15 @@ use serde::{Serialize, Deserialize};
3338
/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
3439
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
3540
/// represented.
41+
///
42+
/// # SIMD
43+
///
44+
/// On nightly Rust and with the [`simd_support`] feature this distribution
45+
/// can also generate multiple samples at once in the form of `std::simd`'s
46+
/// [`maskNxM`](core::simd::Mask) types. Each lane of the mask uses the same
47+
/// probability.
48+
///
49+
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
3650
#[derive(Clone, Copy, Debug, PartialEq)]
3751
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
3852
pub struct Bernoulli {
@@ -140,17 +154,36 @@ impl Distribution<bool> for Bernoulli {
140154
}
141155
}
142156

157+
/// Requires nightly Rust and the [`simd_support`] feature
158+
///
159+
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
160+
#[cfg(feature = "simd_support")]
161+
impl<const LANES: usize> Distribution<Mask<i64, LANES>> for Bernoulli
162+
where
163+
LaneCount<LANES>: SupportedLaneCount,
164+
Standard: Distribution<Simd<u64, LANES>>,
165+
{
166+
// TODO: revisit for https://github.com/rust-random/rand/issues/1227
167+
#[inline]
168+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<i64, LANES> {
169+
if self.p_int == ALWAYS_TRUE {
170+
return Mask::splat(true);
171+
}
172+
rng.gen().lanes_lt(Simd::splat(self.p_int))
173+
}
174+
}
175+
143176
#[cfg(test)]
144177
mod test {
145178
use super::Bernoulli;
146179
use crate::distributions::Distribution;
147180
use crate::Rng;
148181

149182
#[test]
150-
#[cfg(feature="serde1")]
183+
#[cfg(feature = "serde1")]
151184
fn test_serializing_deserializing_bernoulli() {
152185
let coin_flip = Bernoulli::new(0.5).unwrap();
153-
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
186+
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
154187

155188
assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
156189
}

src/distributions/float.rs

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
//! Basic floating-point number distributions
1010
11-
use crate::distributions::utils::FloatSIMDUtils;
11+
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
1212
use crate::distributions::{Distribution, Standard};
1313
use crate::Rng;
1414
use core::mem;
15-
#[cfg(feature = "simd_support")] use packed_simd::*;
15+
#[cfg(feature = "simd_support")] use core::simd::*;
1616

1717
#[cfg(feature = "serde1")]
1818
use serde::{Serialize, Deserialize};
@@ -99,7 +99,7 @@ macro_rules! float_impls {
9999
// The exponent is encoded using an offset-binary representation
100100
let exponent_bits: $u_scalar =
101101
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
102-
$ty::from_bits(self | exponent_bits)
102+
$ty::from_bits(self | $uty::splat(exponent_bits))
103103
}
104104
}
105105

@@ -108,13 +108,13 @@ macro_rules! float_impls {
108108
// Multiply-based method; 24/53 random bits; [0, 1) interval.
109109
// We use the most significant bits because for simple RNGs
110110
// those are usually more random.
111-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
111+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
112112
let precision = $fraction_bits + 1;
113113
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
114114

115115
let value: $uty = rng.gen();
116-
let value = value >> (float_size - precision);
117-
scale * $ty::cast_from_int(value)
116+
let value = value >> $uty::splat(float_size - precision);
117+
$ty::splat(scale) * $ty::cast_from_int(value)
118118
}
119119
}
120120

@@ -123,14 +123,14 @@ macro_rules! float_impls {
123123
// Multiply-based method; 24/53 random bits; (0, 1] interval.
124124
// We use the most significant bits because for simple RNGs
125125
// those are usually more random.
126-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
126+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
127127
let precision = $fraction_bits + 1;
128128
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
129129

130130
let value: $uty = rng.gen();
131-
let value = value >> (float_size - precision);
131+
let value = value >> $uty::splat(float_size - precision);
132132
// Add 1 to shift up; will not overflow because of right-shift:
133-
scale * $ty::cast_from_int(value + 1)
133+
$ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1))
134134
}
135135
}
136136

@@ -140,11 +140,11 @@ macro_rules! float_impls {
140140
// We use the most significant bits because for simple RNGs
141141
// those are usually more random.
142142
use core::$f_scalar::EPSILON;
143-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
143+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
144144

145145
let value: $uty = rng.gen();
146-
let fraction = value >> (float_size - $fraction_bits);
147-
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
146+
let fraction = value >> $uty::splat(float_size - $fraction_bits);
147+
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
148148
}
149149
}
150150
}
@@ -169,10 +169,10 @@ float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
169169
#[cfg(feature = "simd_support")]
170170
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
171171

172-
173172
#[cfg(test)]
174173
mod tests {
175174
use super::*;
175+
use crate::distributions::utils::FloatAsSIMD;
176176
use crate::rngs::mock::StepRng;
177177

178178
const EPSILON32: f32 = ::core::f32::EPSILON;
@@ -182,29 +182,31 @@ mod tests {
182182
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
183183
#[test]
184184
fn $fnn() {
185+
let two = $ty::splat(2.0);
186+
185187
// Standard
186188
let mut zeros = StepRng::new(0, 0);
187189
assert_eq!(zeros.gen::<$ty>(), $ZERO);
188190
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
189-
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
191+
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
190192
let mut max = StepRng::new(!0, 0);
191-
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
193+
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);
192194

193195
// OpenClosed01
194196
let mut zeros = StepRng::new(0, 0);
195-
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
197+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
196198
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
197199
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
198200
let mut max = StepRng::new(!0, 0);
199-
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
201+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));
200202

201203
// Open01
202204
let mut zeros = StepRng::new(0, 0);
203-
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
205+
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
204206
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
205-
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
207+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
206208
let mut max = StepRng::new(!0, 0);
207-
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
209+
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
208210
}
209211
};
210212
}
@@ -222,29 +224,31 @@ mod tests {
222224
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
223225
#[test]
224226
fn $fnn() {
227+
let two = $ty::splat(2.0);
228+
225229
// Standard
226230
let mut zeros = StepRng::new(0, 0);
227231
assert_eq!(zeros.gen::<$ty>(), $ZERO);
228232
let mut one = StepRng::new(1 << 11, 0);
229-
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
233+
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
230234
let mut max = StepRng::new(!0, 0);
231-
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
235+
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);
232236

233237
// OpenClosed01
234238
let mut zeros = StepRng::new(0, 0);
235-
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
239+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
236240
let mut one = StepRng::new(1 << 11, 0);
237241
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
238242
let mut max = StepRng::new(!0, 0);
239-
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
243+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));
240244

241245
// Open01
242246
let mut zeros = StepRng::new(0, 0);
243-
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
247+
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
244248
let mut one = StepRng::new(1 << 12, 0);
245-
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
249+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
246250
let mut max = StepRng::new(!0, 0);
247-
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
251+
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
248252
}
249253
};
250254
}
@@ -296,16 +300,16 @@ mod tests {
296300
// non-SIMD types; we assume this pattern continues across all
297301
// SIMD types.
298302

299-
test_samples(&Standard, f32x2::new(0.0, 0.0), &[
300-
f32x2::new(0.0035963655, 0.7346052),
301-
f32x2::new(0.09778172, 0.20298547),
302-
f32x2::new(0.34296435, 0.81664366),
303+
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[
304+
f32x2::from([0.0035963655, 0.7346052]),
305+
f32x2::from([0.09778172, 0.20298547]),
306+
f32x2::from([0.34296435, 0.81664366]),
303307
]);
304308

305-
test_samples(&Standard, f64x2::new(0.0, 0.0), &[
306-
f64x2::new(0.7346051961657583, 0.20298547462974248),
307-
f64x2::new(0.8166436635290655, 0.7423708925400552),
308-
f64x2::new(0.16387782224016323, 0.9087068770169618),
309+
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[
310+
f64x2::from([0.7346051961657583, 0.20298547462974248]),
311+
f64x2::from([0.8166436635290655, 0.7423708925400552]),
312+
f64x2::from([0.16387782224016323, 0.9087068770169618]),
309313
]);
310314
}
311315
}

0 commit comments

Comments
 (0)