Skip to content

Commit 599d7f8

Browse files
committed
switch to std::simd, expand SIMD stuff & docs
move __m128i to stable, expand documentation, add SIMD to Bernoulli, add maskNxM, add __m512i genericize simd uniform int remove some debug stuff remove bernoulli foo foo
1 parent 3543f4b commit 599d7f8

File tree

9 files changed

+281
-230
lines changed

9 files changed

+281
-230
lines changed

Cargo.toml

+2-9
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ alloc = ["rand_core/alloc"]
4141
# Option: use getrandom package for seeding
4242
getrandom = ["rand_core/getrandom"]
4343

44-
# Option (requires nightly): experimental SIMD support
45-
simd_support = ["packed_simd"]
44+
# Option (requires nightly Rust): experimental SIMD support
45+
simd_support = []
4646

4747
# Option (enabled by default): enable StdRng
4848
std_rng = ["rand_chacha"]
@@ -68,13 +68,6 @@ log = { version = "0.4.4", optional = true }
6868
serde = { version = "1.0.103", features = ["derive"], optional = true }
6969
rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true }
7070

71-
[dependencies.packed_simd]
72-
# NOTE: so far no version works reliably due to dependence on unstable features
73-
package = "packed_simd_2"
74-
version = "0.3.7"
75-
optional = true
76-
features = ["into_bits"]
77-
7871
[target.'cfg(unix)'.dependencies]
7972
# Used for fork protection (reseeding.rs)
8073
libc = { version = "0.2.22", optional = true, default-features = false }

src/distributions/bernoulli.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use core::{fmt, u64};
1414

1515
#[cfg(feature = "serde1")]
1616
use serde::{Serialize, Deserialize};
17+
1718
/// The Bernoulli distribution.
1819
///
1920
/// This is a special case of the Binomial distribution where `n = 1`.
@@ -147,10 +148,10 @@ mod test {
147148
use crate::Rng;
148149

149150
#[test]
150-
#[cfg(feature="serde1")]
151+
#[cfg(feature = "serde1")]
151152
fn test_serializing_deserializing_bernoulli() {
152153
let coin_flip = Bernoulli::new(0.5).unwrap();
153-
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
154+
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
154155

155156
assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
156157
}

src/distributions/float.rs

+39-35
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
//! Basic floating-point number distributions
1010
11-
use crate::distributions::utils::FloatSIMDUtils;
11+
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
1212
use crate::distributions::{Distribution, Standard};
1313
use crate::Rng;
1414
use core::mem;
15-
#[cfg(feature = "simd_support")] use packed_simd::*;
15+
#[cfg(feature = "simd_support")] use core::simd::*;
1616

1717
#[cfg(feature = "serde1")]
1818
use serde::{Serialize, Deserialize};
@@ -99,7 +99,7 @@ macro_rules! float_impls {
9999
// The exponent is encoded using an offset-binary representation
100100
let exponent_bits: $u_scalar =
101101
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
102-
$ty::from_bits(self | exponent_bits)
102+
$ty::from_bits(self | $uty::splat(exponent_bits))
103103
}
104104
}
105105

@@ -108,13 +108,13 @@ macro_rules! float_impls {
108108
// Multiply-based method; 24/53 random bits; [0, 1) interval.
109109
// We use the most significant bits because for simple RNGs
110110
// those are usually more random.
111-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
111+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
112112
let precision = $fraction_bits + 1;
113113
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
114114

115115
let value: $uty = rng.gen();
116-
let value = value >> (float_size - precision);
117-
scale * $ty::cast_from_int(value)
116+
let value = value >> $uty::splat(float_size - precision);
117+
$ty::splat(scale) * $ty::cast_from_int(value)
118118
}
119119
}
120120

@@ -123,14 +123,14 @@ macro_rules! float_impls {
123123
// Multiply-based method; 24/53 random bits; (0, 1] interval.
124124
// We use the most significant bits because for simple RNGs
125125
// those are usually more random.
126-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
126+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
127127
let precision = $fraction_bits + 1;
128128
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
129129

130130
let value: $uty = rng.gen();
131-
let value = value >> (float_size - precision);
131+
let value = value >> $uty::splat(float_size - precision);
132132
// Add 1 to shift up; will not overflow because of right-shift:
133-
scale * $ty::cast_from_int(value + 1)
133+
$ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1))
134134
}
135135
}
136136

@@ -140,11 +140,11 @@ macro_rules! float_impls {
140140
// We use the most significant bits because for simple RNGs
141141
// those are usually more random.
142142
use core::$f_scalar::EPSILON;
143-
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
143+
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
144144

145145
let value: $uty = rng.gen();
146-
let fraction = value >> (float_size - $fraction_bits);
147-
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
146+
let fraction = value >> $uty::splat(float_size - $fraction_bits);
147+
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
148148
}
149149
}
150150
}
@@ -169,10 +169,10 @@ float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
169169
#[cfg(feature = "simd_support")]
170170
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
171171

172-
173172
#[cfg(test)]
174173
mod tests {
175174
use super::*;
175+
use crate::distributions::utils::FloatAsSIMD;
176176
use crate::rngs::mock::StepRng;
177177

178178
const EPSILON32: f32 = ::core::f32::EPSILON;
@@ -182,29 +182,31 @@ mod tests {
182182
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
183183
#[test]
184184
fn $fnn() {
185+
let two = $ty::splat(2.0);
186+
185187
// Standard
186188
let mut zeros = StepRng::new(0, 0);
187189
assert_eq!(zeros.gen::<$ty>(), $ZERO);
188190
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
189-
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
191+
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
190192
let mut max = StepRng::new(!0, 0);
191-
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
193+
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);
192194

193195
// OpenClosed01
194196
let mut zeros = StepRng::new(0, 0);
195-
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
197+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
196198
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
197199
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
198200
let mut max = StepRng::new(!0, 0);
199-
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
201+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));
200202

201203
// Open01
202204
let mut zeros = StepRng::new(0, 0);
203-
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
205+
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
204206
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
205-
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
207+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
206208
let mut max = StepRng::new(!0, 0);
207-
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
209+
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
208210
}
209211
};
210212
}
@@ -222,29 +224,31 @@ mod tests {
222224
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
223225
#[test]
224226
fn $fnn() {
227+
let two = $ty::splat(2.0);
228+
225229
// Standard
226230
let mut zeros = StepRng::new(0, 0);
227231
assert_eq!(zeros.gen::<$ty>(), $ZERO);
228232
let mut one = StepRng::new(1 << 11, 0);
229-
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
233+
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
230234
let mut max = StepRng::new(!0, 0);
231-
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
235+
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);
232236

233237
// OpenClosed01
234238
let mut zeros = StepRng::new(0, 0);
235-
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
239+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
236240
let mut one = StepRng::new(1 << 11, 0);
237241
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
238242
let mut max = StepRng::new(!0, 0);
239-
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
243+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));
240244

241245
// Open01
242246
let mut zeros = StepRng::new(0, 0);
243-
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
247+
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
244248
let mut one = StepRng::new(1 << 12, 0);
245-
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
249+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
246250
let mut max = StepRng::new(!0, 0);
247-
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
251+
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
248252
}
249253
};
250254
}
@@ -296,16 +300,16 @@ mod tests {
296300
// non-SIMD types; we assume this pattern continues across all
297301
// SIMD types.
298302

299-
test_samples(&Standard, f32x2::new(0.0, 0.0), &[
300-
f32x2::new(0.0035963655, 0.7346052),
301-
f32x2::new(0.09778172, 0.20298547),
302-
f32x2::new(0.34296435, 0.81664366),
303+
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[
304+
f32x2::from([0.0035963655, 0.7346052]),
305+
f32x2::from([0.09778172, 0.20298547]),
306+
f32x2::from([0.34296435, 0.81664366]),
303307
]);
304308

305-
test_samples(&Standard, f64x2::new(0.0, 0.0), &[
306-
f64x2::new(0.7346051961657583, 0.20298547462974248),
307-
f64x2::new(0.8166436635290655, 0.7423708925400552),
308-
f64x2::new(0.16387782224016323, 0.9087068770169618),
309+
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[
310+
f64x2::from([0.7346051961657583, 0.20298547462974248]),
311+
f64x2::from([0.8166436635290655, 0.7423708925400552]),
312+
f64x2::from([0.16387782224016323, 0.9087068770169618]),
309313
]);
310314
}
311315
}

0 commit comments

Comments
 (0)