Skip to content

Commit d4a2945

Browse files
authored
Merge pull request #1289 from dhardy/uniform-float
Uniform float improvements
2 parents 1464b88 + 026292d commit d4a2945

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,8 @@ harness = false
9494
name = "shuffle"
9595
path = "benches/shuffle.rs"
9696
harness = false
97+
98+
[[bench]]
99+
name = "uniform_float"
100+
path = "benches/uniform_float.rs"
101+
harness = false

benches/uniform_float.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright 2023 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
//! Implement benchmarks for uniform distributions over FP types
10+
//!
11+
//! Sampling methods compared:
12+
//!
13+
//! - sample: current method: (x12 - 1.0) * (b - a) + a
14+
15+
use core::time::Duration;
16+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
17+
use rand::distributions::uniform::{SampleUniform, Uniform, UniformSampler};
18+
use rand::prelude::*;
19+
use rand_chacha::ChaCha8Rng;
20+
use rand_pcg::{Pcg32, Pcg64};
21+
22+
const WARM_UP_TIME: Duration = Duration::from_millis(1000);
23+
const MEASUREMENT_TIME: Duration = Duration::from_secs(3);
24+
const SAMPLE_SIZE: usize = 100_000;
25+
const N_RESAMPLES: usize = 10_000;
26+
27+
macro_rules! single_random {
28+
($R:ty, $T:ty, $g:expr) => {
29+
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
30+
let mut rng = <$R>::from_entropy();
31+
let (mut low, mut high);
32+
loop {
33+
low = <$T>::from_bits(rng.gen());
34+
high = <$T>::from_bits(rng.gen());
35+
if (low < high) && (high - low).is_normal() {
36+
break;
37+
}
38+
}
39+
40+
b.iter(|| <$T as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng));
41+
});
42+
};
43+
44+
($c:expr, $T:ty) => {{
45+
let mut g = $c.benchmark_group("uniform_single");
46+
g.sample_size(SAMPLE_SIZE);
47+
g.warm_up_time(WARM_UP_TIME);
48+
g.measurement_time(MEASUREMENT_TIME);
49+
g.nresamples(N_RESAMPLES);
50+
single_random!(SmallRng, $T, g);
51+
single_random!(ChaCha8Rng, $T, g);
52+
single_random!(Pcg32, $T, g);
53+
single_random!(Pcg64, $T, g);
54+
g.finish();
55+
}};
56+
}
57+
58+
fn single_random(c: &mut Criterion) {
59+
single_random!(c, f32);
60+
single_random!(c, f64);
61+
}
62+
63+
macro_rules! distr_random {
64+
($R:ty, $T:ty, $g:expr) => {
65+
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
66+
let mut rng = <$R>::from_entropy();
67+
let dist = loop {
68+
let low = <$T>::from_bits(rng.gen());
69+
let high = <$T>::from_bits(rng.gen());
70+
if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) {
71+
break dist;
72+
}
73+
};
74+
75+
b.iter(|| dist.sample(&mut rng));
76+
});
77+
};
78+
79+
($c:expr, $T:ty) => {{
80+
let mut g = $c.benchmark_group("uniform_distribution");
81+
g.sample_size(SAMPLE_SIZE);
82+
g.warm_up_time(WARM_UP_TIME);
83+
g.measurement_time(MEASUREMENT_TIME);
84+
g.nresamples(N_RESAMPLES);
85+
distr_random!(SmallRng, $T, g);
86+
distr_random!(ChaCha8Rng, $T, g);
87+
distr_random!(Pcg32, $T, g);
88+
distr_random!(Pcg64, $T, g);
89+
g.finish();
90+
}};
91+
}
92+
93+
fn distr_random(c: &mut Criterion) {
94+
distr_random!(c, f32);
95+
distr_random!(c, f64);
96+
}
97+
98+
criterion_group! {
99+
name = benches;
100+
config = Criterion::default();
101+
targets = single_random, distr_random
102+
}
103+
criterion_main!(benches);

src/distributions/uniform.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,38 @@ macro_rules! uniform_float_impl {
10381038
}
10391039
}
10401040
}
1041+
1042+
#[inline]
1043+
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
1044+
where
1045+
B1: SampleBorrow<Self::X> + Sized,
1046+
B2: SampleBorrow<Self::X> + Sized,
1047+
{
1048+
let low = *low_b.borrow();
1049+
let high = *high_b.borrow();
1050+
#[cfg(debug_assertions)]
1051+
if !low.all_finite() || !high.all_finite() {
1052+
return Err(Error::NonFinite);
1053+
}
1054+
if !low.all_le(high) {
1055+
return Err(Error::EmptyRange);
1056+
}
1057+
let scale = high - low;
1058+
if !scale.all_finite() {
1059+
return Err(Error::NonFinite);
1060+
}
1061+
1062+
// Generate a value in the range [1, 2)
1063+
let value1_2 =
1064+
(rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);
1065+
1066+
// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
1067+
let value0_1 = value1_2 - <$ty>::splat(1.0);
1068+
1069+
// Doing multiply before addition allows some architectures
1070+
// to use a single instruction.
1071+
Ok(value0_1 * scale + low)
1072+
}
10411073
}
10421074
};
10431075
}
@@ -1380,6 +1412,9 @@ mod tests {
13801412
let v = <$ty as SampleUniform>::Sampler
13811413
::sample_single(low, high, &mut rng).unwrap().extract(lane);
13821414
assert!(low_scalar <= v && v < high_scalar);
1415+
let v = <$ty as SampleUniform>::Sampler
1416+
::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane);
1417+
assert!(low_scalar <= v && v <= high_scalar);
13831418
}
13841419

13851420
assert_eq!(
@@ -1392,8 +1427,19 @@ mod tests {
13921427
assert_eq!(<$ty as SampleUniform>::Sampler
13931428
::sample_single(low, high, &mut zero_rng).unwrap()
13941429
.extract(lane), low_scalar);
1430+
assert_eq!(<$ty as SampleUniform>::Sampler
1431+
::sample_single_inclusive(low, high, &mut zero_rng).unwrap()
1432+
.extract(lane), low_scalar);
1433+
13951434
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
13961435
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
1436+
// sample_single cannot cope with max_rng:
1437+
// assert!(<$ty as SampleUniform>::Sampler
1438+
// ::sample_single(low, high, &mut max_rng).unwrap()
1439+
// .extract(lane) < high_scalar);
1440+
assert!(<$ty as SampleUniform>::Sampler
1441+
::sample_single_inclusive(low, high, &mut max_rng).unwrap()
1442+
.extract(lane) <= high_scalar);
13971443

13981444
// Don't run this test for really tiny differences between high and low
13991445
// since for those rounding might result in selecting high for a very

0 commit comments

Comments
 (0)