Skip to content

Commit 950c0af

Browse files
authored
Merge pull request #523 from pitdicker/simd_support_basic
Add basic SIMD support
2 parents 3af227a + 5c948fe commit 950c0af

File tree

7 files changed

+344
-149
lines changed

7 files changed

+344
-149
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ appveyor = { repository = "alexcrichton/rand" }
1919

2020
[features]
2121
default = ["std" ] # without "std" rand uses libcore
22-
nightly = ["i128_support"] # enables all features requiring nightly rust
22+
nightly = ["i128_support", "simd_support"] # enables all features requiring nightly rust
2323
std = ["rand_core/std", "alloc", "libc", "winapi", "cloudabi", "fuchsia-zircon"]
2424
alloc = ["rand_core/alloc"] # enables Vec and Box support (without std)
2525
i128_support = [] # enables i128 and u128 support
26+
simd_support = [] # enables SIMD support
2627
serde1 = ["serde", "serde_derive", "rand_core/serde1"] # enables serialization for PRNGs
2728

2829
[workspace]

src/distributions/float.rs

+110-57
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
use core::mem;
1414
use Rng;
1515
use distributions::{Distribution, Standard};
16+
use distributions::utils::CastFromInt;
17+
#[cfg(feature="simd_support")]
18+
use core::simd::*;
1619

1720
/// A distribution to sample floating point numbers uniformly in the half-open
1821
/// interval `(0, 1]`, i.e. including 1 but not 0.
@@ -83,15 +86,16 @@ pub(crate) trait IntoFloat {
8386
}
8487

8588
macro_rules! float_impls {
86-
($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => {
89+
($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty,
90+
$fraction_bits:expr, $exponent_bias:expr) => {
8791
impl IntoFloat for $uty {
8892
type F = $ty;
8993
#[inline(always)]
9094
fn into_float_with_exponent(self, exponent: i32) -> $ty {
9195
// The exponent is encoded using an offset-binary representation
92-
let exponent_bits =
93-
(($exponent_bias + exponent) as $uty) << $fraction_bits;
94-
unsafe { mem::transmute(self | exponent_bits) }
96+
let exponent_bits: $u_scalar =
97+
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
98+
$ty::from_bits(self | exponent_bits)
9599
}
96100
}
97101

@@ -100,12 +104,13 @@ macro_rules! float_impls {
100104
// Multiply-based method; 24/53 random bits; [0, 1) interval.
101105
// We use the most significant bits because for simple RNGs
102106
// those are usually more random.
103-
let float_size = mem::size_of::<$ty>() * 8;
107+
let float_size = mem::size_of::<$f_scalar>() * 8;
104108
let precision = $fraction_bits + 1;
105-
let scale = 1.0 / ((1 as $uty << precision) as $ty);
109+
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
106110

107111
let value: $uty = rng.gen();
108-
scale * (value >> (float_size - precision)) as $ty
112+
let value = value >> (float_size - precision);
113+
scale * $ty::cast_from_int(value)
109114
}
110115
}
111116

@@ -114,14 +119,14 @@ macro_rules! float_impls {
114119
// Multiply-based method; 24/53 random bits; (0, 1] interval.
115120
// We use the most significant bits because for simple RNGs
116121
// those are usually more random.
117-
let float_size = mem::size_of::<$ty>() * 8;
122+
let float_size = mem::size_of::<$f_scalar>() * 8;
118123
let precision = $fraction_bits + 1;
119-
let scale = 1.0 / ((1 as $uty << precision) as $ty);
124+
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
120125

121126
let value: $uty = rng.gen();
122127
let value = value >> (float_size - precision);
123128
// Add 1 to shift up; will not overflow because of right-shift:
124-
scale * (value + 1) as $ty
129+
scale * $ty::cast_from_int(value + 1)
125130
}
126131
}
127132

@@ -130,8 +135,8 @@ macro_rules! float_impls {
130135
// Transmute-based method; 23/52 random bits; (0, 1) interval.
131136
// We use the most significant bits because for simple RNGs
132137
// those are usually more random.
133-
const EPSILON: $ty = 1.0 / (1u64 << $fraction_bits) as $ty;
134-
let float_size = mem::size_of::<$ty>() * 8;
138+
use core::$f_scalar::EPSILON;
139+
let float_size = mem::size_of::<$f_scalar>() * 8;
135140

136141
let value: $uty = rng.gen();
137142
let fraction = value >> (float_size - $fraction_bits);
@@ -140,67 +145,115 @@ macro_rules! float_impls {
140145
}
141146
}
142147
}
143-
float_impls! { f32, u32, 23, 127 }
144-
float_impls! { f64, u64, 52, 1023 }
148+
149+
float_impls! { f32, u32, f32, u32, 23, 127 }
150+
float_impls! { f64, u64, f64, u64, 52, 1023 }
151+
152+
#[cfg(feature="simd_support")]
153+
float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
154+
#[cfg(feature="simd_support")]
155+
float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
156+
#[cfg(feature="simd_support")]
157+
float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
158+
#[cfg(feature="simd_support")]
159+
float_impls! { f32x16, u32x16, f32, u32, 23, 127 }
160+
161+
#[cfg(feature="simd_support")]
162+
float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
163+
#[cfg(feature="simd_support")]
164+
float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
165+
#[cfg(feature="simd_support")]
166+
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
145167

146168

147169
#[cfg(test)]
148170
mod tests {
149171
use Rng;
150172
use distributions::{Open01, OpenClosed01};
151173
use rngs::mock::StepRng;
174+
#[cfg(feature="simd_support")]
175+
use core::simd::*;
152176

153177
const EPSILON32: f32 = ::core::f32::EPSILON;
154178
const EPSILON64: f64 = ::core::f64::EPSILON;
155179

156-
#[test]
157-
fn standard_fp_edge_cases() {
158-
let mut zeros = StepRng::new(0, 0);
159-
assert_eq!(zeros.gen::<f32>(), 0.0);
160-
assert_eq!(zeros.gen::<f64>(), 0.0);
161-
162-
let mut one32 = StepRng::new(1 << 8, 0);
163-
assert_eq!(one32.gen::<f32>(), EPSILON32 / 2.0);
164-
165-
let mut one64 = StepRng::new(1 << 11, 0);
166-
assert_eq!(one64.gen::<f64>(), EPSILON64 / 2.0);
167-
168-
let mut max = StepRng::new(!0, 0);
169-
assert_eq!(max.gen::<f32>(), 1.0 - EPSILON32 / 2.0);
170-
assert_eq!(max.gen::<f64>(), 1.0 - EPSILON64 / 2.0);
171-
}
180+
macro_rules! test_f32 {
181+
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
182+
#[test]
183+
fn $fnn() {
184+
// Standard
185+
let mut zeros = StepRng::new(0, 0);
186+
assert_eq!(zeros.gen::<$ty>(), $ZERO);
187+
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
188+
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
189+
let mut max = StepRng::new(!0, 0);
190+
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
172191

173-
#[test]
174-
fn openclosed01_edge_cases() {
175-
let mut zeros = StepRng::new(0, 0);
176-
assert_eq!(zeros.sample::<f32, _>(OpenClosed01), 0.0 + EPSILON32 / 2.0);
177-
assert_eq!(zeros.sample::<f64, _>(OpenClosed01), 0.0 + EPSILON64 / 2.0);
192+
// OpenClosed01
193+
let mut zeros = StepRng::new(0, 0);
194+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
195+
0.0 + $EPSILON / 2.0);
196+
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
197+
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
198+
let mut max = StepRng::new(!0, 0);
199+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
178200

179-
let mut one32 = StepRng::new(1 << 8, 0);
180-
assert_eq!(one32.sample::<f32, _>(OpenClosed01), EPSILON32);
181-
182-
let mut one64 = StepRng::new(1 << 11, 0);
183-
assert_eq!(one64.sample::<f64, _>(OpenClosed01), EPSILON64);
184-
185-
let mut max = StepRng::new(!0, 0);
186-
assert_eq!(max.sample::<f32, _>(OpenClosed01), 1.0);
187-
assert_eq!(max.sample::<f64, _>(OpenClosed01), 1.0);
201+
// Open01
202+
let mut zeros = StepRng::new(0, 0);
203+
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
204+
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
205+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
206+
let mut max = StepRng::new(!0, 0);
207+
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
208+
}
209+
}
188210
}
211+
test_f32! { f32_edge_cases, f32, 0.0, EPSILON32 }
212+
#[cfg(feature="simd_support")]
213+
test_f32! { f32x2_edge_cases, f32x2, f32x2::splat(0.0), f32x2::splat(EPSILON32) }
214+
#[cfg(feature="simd_support")]
215+
test_f32! { f32x4_edge_cases, f32x4, f32x4::splat(0.0), f32x4::splat(EPSILON32) }
216+
#[cfg(feature="simd_support")]
217+
test_f32! { f32x8_edge_cases, f32x8, f32x8::splat(0.0), f32x8::splat(EPSILON32) }
218+
#[cfg(feature="simd_support")]
219+
test_f32! { f32x16_edge_cases, f32x16, f32x16::splat(0.0), f32x16::splat(EPSILON32) }
189220

190-
#[test]
191-
fn open01_edge_cases() {
192-
let mut zeros = StepRng::new(0, 0);
193-
assert_eq!(zeros.sample::<f32, _>(Open01), 0.0 + EPSILON32 / 2.0);
194-
assert_eq!(zeros.sample::<f64, _>(Open01), 0.0 + EPSILON64 / 2.0);
221+
macro_rules! test_f64 {
222+
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
223+
#[test]
224+
fn $fnn() {
225+
// Standard
226+
let mut zeros = StepRng::new(0, 0);
227+
assert_eq!(zeros.gen::<$ty>(), $ZERO);
228+
let mut one = StepRng::new(1 << 11, 0);
229+
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
230+
let mut max = StepRng::new(!0, 0);
231+
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
195232

196-
let mut one32 = StepRng::new(1 << 9, 0);
197-
assert_eq!(one32.sample::<f32, _>(Open01), EPSILON32 / 2.0 * 3.0);
233+
// OpenClosed01
234+
let mut zeros = StepRng::new(0, 0);
235+
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
236+
0.0 + $EPSILON / 2.0);
237+
let mut one = StepRng::new(1 << 11, 0);
238+
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
239+
let mut max = StepRng::new(!0, 0);
240+
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
198241

199-
let mut one64 = StepRng::new(1 << 12, 0);
200-
assert_eq!(one64.sample::<f64, _>(Open01), EPSILON64 / 2.0 * 3.0);
201-
202-
let mut max = StepRng::new(!0, 0);
203-
assert_eq!(max.sample::<f32, _>(Open01), 1.0 - EPSILON32 / 2.0);
204-
assert_eq!(max.sample::<f64, _>(Open01), 1.0 - EPSILON64 / 2.0);
242+
// Open01
243+
let mut zeros = StepRng::new(0, 0);
244+
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
245+
let mut one = StepRng::new(1 << 12, 0);
246+
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
247+
let mut max = StepRng::new(!0, 0);
248+
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
249+
}
250+
}
205251
}
252+
test_f64! { f64_edge_cases, f64, 0.0, EPSILON64 }
253+
#[cfg(feature="simd_support")]
254+
test_f64! { f64x2_edge_cases, f64x2, f64x2::splat(0.0), f64x2::splat(EPSILON64) }
255+
#[cfg(feature="simd_support")]
256+
test_f64! { f64x4_edge_cases, f64x4, f64x4::splat(0.0), f64x4::splat(EPSILON64) }
257+
#[cfg(feature="simd_support")]
258+
test_f64! { f64x8_edge_cases, f64x8, f64x8::splat(0.0), f64x8::splat(EPSILON64) }
206259
}

src/distributions/integer.rs

+35
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
1313
use {Rng};
1414
use distributions::{Distribution, Standard};
15+
#[cfg(feature="simd_support")]
16+
use core::simd::*;
1517

1618
impl Distribution<u8> for Standard {
1719
#[inline]
@@ -84,6 +86,39 @@ impl_int_from_uint! { i64, u64 }
8486
#[cfg(feature = "i128_support")] impl_int_from_uint! { i128, u128 }
8587
impl_int_from_uint! { isize, usize }
8688

89+
// Generates `impl Distribution<$ty> for Standard` for every type in the
// comma-terminated list that follows the leading `$bits` argument. Each
// generated `sample` starts from `Default::default()` and then overwrites the
// value through a `[u8; $bits/8]` view filled by `rng.fill_bytes`.
#[cfg(feature="simd_support")]
90+
macro_rules! simd_impl {
91+
// Base case: the type list is exhausted, so nothing is emitted.
($bits:expr,) => {};
92+
// Recursive case: peel off the first type, recurse on the rest, then emit
// the impl for the peeled-off `$ty`.
($bits:expr, $ty:ty, $($ty_more:ty,)*) => {
93+
simd_impl!($bits, $($ty_more,)*);
94+
95+
impl Distribution<$ty> for Standard {
96+
#[inline]
97+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
98+
// `vec` is the value returned below, so its type is inferred as `$ty`.
let mut vec = Default::default();
99+
// NOTE(review): this block reinterprets `vec` as `[u8; $bits/8]`. That is
// only sound if `$ty` occupies exactly `$bits/8` bytes and every byte
// pattern is a valid `$ty`; neither is checked here -- confirm for each
// invocation of this macro.
unsafe {
100+
let ptr = &mut vec;
101+
// Cast `&mut $ty` -> `*mut $ty` -> `*mut [u8; $bits/8]` and reborrow it as
// a mutable byte array so it can be handed to `fill_bytes`.
let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]);
102+
rng.fill_bytes(b_ptr);
103+
}
104+
vec
105+
}
106+
}
107+
}
108+
}
109+
110+
// Instantiations of `simd_impl!`, one per width, each gated on the
// `simd_support` feature. The first argument is passed as `$bits` (the macro
// sizes its byte view as `$bits/8`); the remaining arguments are the types
// that receive an `impl Distribution<_> for Standard`.
// NOTE(review): the type names suggest lane count x lane width equals the
// stated width (e.g. `u8x2` -> 16, `u64x8` -> 512) -- confirm against the
// SIMD type definitions, since the macro's unsafe cast depends on it.
#[cfg(feature="simd_support")]
111+
simd_impl!(16, u8x2, i8x2,);
112+
#[cfg(feature="simd_support")]
113+
simd_impl!(32, u8x4, i8x4, u16x2, i16x2,);
114+
#[cfg(feature="simd_support")]
115+
simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,);
116+
#[cfg(feature="simd_support")]
117+
simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,);
118+
#[cfg(feature="simd_support")]
119+
simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,);
120+
#[cfg(feature="simd_support")]
121+
simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,);
87122

88123
#[cfg(test)]
89124
mod tests {

src/distributions/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ mod integer;
215215
#[cfg(feature="std")]
216216
mod log_gamma;
217217
mod other;
218+
mod utils;
218219
#[cfg(feature="std")]
219220
mod ziggurat_tables;
220221
#[cfg(feature="std")]

0 commit comments

Comments
 (0)