Skip to content

Commit ec3d7ef

Browse files
authored
Merge pull request #500 from pitdicker/optimize_bernoulli_new
Optimize Bernoulli::new
2 parents 7ae36ee + 6044cc8 commit ec3d7ef

File tree

3 files changed

+34
-43
lines changed

3 files changed

+34
-43
lines changed

benches/misc.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ fn misc_gen_ratio_var(b: &mut Bencher) {
6363
#[bench]
6464
fn misc_bernoulli_const(b: &mut Bencher) {
6565
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
66-
let d = rand::distributions::Bernoulli::new(0.18);
6766
b.iter(|| {
68-
// Can be evaluated at compile time.
67+
let d = rand::distributions::Bernoulli::new(0.18);
6968
let mut accum = true;
7069
for _ in 0..::RAND_BENCH_N {
7170
accum ^= rng.sample(d);

src/distributions/bernoulli.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ pub struct Bernoulli {
3737
p_int: u64,
3838
}
3939

40+
// To sample from the Bernoulli distribution we use a method that compares a
41+
// random `u64` value `v < (p * 2^64)`.
42+
//
43+
// If `p == 1.0`, the integer `p * 2^64` to compare against cannot be represented as a
44+
// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64).
45+
// Note that a value of `p < 1.0` can never result in `u64::MAX`, because an
46+
// `f64` only has 53 bits of precision, and the next largest value of `p` will
47+
// result in `2^64 - 2048`.
48+
//
49+
// Also there is a purely theoretical concern: if someone consistently wants to
50+
// generate `true` using the Bernoulli distribution (i.e. by using a probability
51+
// of `1.0`), just using `u64::MAX` is not enough. On average it would return
52+
// false once every 2^64 iterations. Some people apparently care about this
53+
// case.
54+
//
55+
// That is why we special-case `u64::MAX` to always return `true`, without using
56+
// the RNG, and pay the performance price for all uses that *are* reasonable.
57+
// Luckily, if `new()` and `sample` are close, the compiler can optimize out the
58+
// extra check.
59+
const ALWAYS_TRUE: u64 = ::core::u64::MAX;
60+
61+
// This is just `2.0.powi(64)`, but written this way because it is not available
62+
// in `no_std` mode.
63+
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
64+
4065
impl Bernoulli {
4166
/// Construct a new `Bernoulli` with the given probability of success `p`.
4267
///
@@ -54,18 +79,11 @@ impl Bernoulli {
5479
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
5580
#[inline]
5681
pub fn new(p: f64) -> Bernoulli {
57-
assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 0");
58-
// Technically, this should be 2^64 or `u64::MAX + 1` because we compare
59-
// using `<` when sampling. However, `u64::MAX` rounds to an `f64`
60-
// larger than `u64::MAX` anyway.
61-
const MAX_P_INT: f64 = ::core::u64::MAX as f64;
62-
let p_int = if p < 1.0 {
63-
(p * MAX_P_INT) as u64
64-
} else {
65-
// Avoid overflow: `MAX_P_INT` cannot be represented as u64.
66-
::core::u64::MAX
67-
};
68-
Bernoulli { p_int }
82+
if p < 0.0 || p >= 1.0 {
83+
if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } }
84+
panic!("Bernoulli::new not called with 0.0 <= p <= 1.0");
85+
}
86+
Bernoulli { p_int: (p * SCALE) as u64 }
6987
}
7088

7189
/// Construct a new `Bernoulli` with the probability of success of
@@ -85,7 +103,6 @@ impl Bernoulli {
85103
if numerator == denominator {
86104
return Bernoulli { p_int: ::core::u64::MAX }
87105
}
88-
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
89106
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
90107
Bernoulli { p_int }
91108
}
@@ -95,11 +112,9 @@ impl Distribution<bool> for Bernoulli {
95112
#[inline]
96113
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
97114
// Make sure to always return true for p = 1.0.
98-
if self.p_int == ::core::u64::MAX {
99-
return true;
100-
}
101-
let r: u64 = rng.gen();
102-
r < self.p_int
115+
if self.p_int == ALWAYS_TRUE { return true; }
116+
let v: u64 = rng.gen();
117+
v < self.p_int
103118
}
104119
}
105120

tests/bool.rs

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)