@@ -37,6 +37,31 @@ pub struct Bernoulli {
37
37
p_int : u64 ,
38
38
}
39
39
40
+ // To sample from the Bernoulli distribution we use a method that compares a
41
+ // random `u64` value `v < (p * 2^64)`.
42
+ //
43
+ // If `p == 1.0`, the integer `v` to compare against can not represented as a
44
+ // `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64).
45
+ // Note that value of `p < 1.0` can never result in `u64::MAX`, because an
46
+ // `f64` only has 53 bits of precision, and the next largest value of `p` will
47
+ // result in `2^64 - 2048`.
48
+ //
49
+ // Also there is a 100% theoretical concern: if someone consistenly wants to
50
+ // generate `true` using the Bernoulli distribution (i.e. by using a probability
51
+ // of `1.0`), just using `u64::MAX` is not enough. On average it would return
52
+ // false once every 2^64 iterations. Some people apparently care about this
53
+ // case.
54
+ //
55
+ // That is why we special-case `u64::MAX` to always return `true`, without using
56
+ // the RNG, and pay the performance price for all uses that *are* reasonable.
57
+ // Luckily, if `new()` and `sample` are close, the compiler can optimize out the
58
+ // extra check.
59
+ const ALWAYS_TRUE : u64 = :: core:: u64:: MAX ;
60
+
61
+ // This is just `2.0.powi(64)`, but written this way because it is not available
62
+ // in `no_std` mode.
63
+ const SCALE : f64 = 2.0 * ( 1u64 << 63 ) as f64 ;
64
+
40
65
impl Bernoulli {
41
66
/// Construct a new `Bernoulli` with the given probability of success `p`.
42
67
///
@@ -54,18 +79,11 @@ impl Bernoulli {
54
79
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
55
80
#[ inline]
56
81
pub fn new ( p : f64 ) -> Bernoulli {
57
- assert ! ( ( p >= 0.0 ) & ( p <= 1.0 ) , "Bernoulli::new not called with 0 <= p <= 0" ) ;
58
- // Technically, this should be 2^64 or `u64::MAX + 1` because we compare
59
- // using `<` when sampling. However, `u64::MAX` rounds to an `f64`
60
- // larger than `u64::MAX` anyway.
61
- const MAX_P_INT : f64 = :: core:: u64:: MAX as f64 ;
62
- let p_int = if p < 1.0 {
63
- ( p * MAX_P_INT ) as u64
64
- } else {
65
- // Avoid overflow: `MAX_P_INT` cannot be represented as u64.
66
- :: core:: u64:: MAX
67
- } ;
68
- Bernoulli { p_int }
82
+ if p < 0.0 || p >= 1.0 {
83
+ if p == 1.0 { return Bernoulli { p_int : ALWAYS_TRUE } }
84
+ panic ! ( "Bernoulli::new not called with 0.0 <= p <= 1.0" ) ;
85
+ }
86
+ Bernoulli { p_int : ( p * SCALE ) as u64 }
69
87
}
70
88
71
89
/// Construct a new `Bernoulli` with the probability of success of
@@ -85,7 +103,6 @@ impl Bernoulli {
85
103
if numerator == denominator {
86
104
return Bernoulli { p_int : :: core:: u64:: MAX }
87
105
}
88
- const SCALE : f64 = 2.0 * ( 1u64 << 63 ) as f64 ;
89
106
let p_int = ( ( numerator as f64 / denominator as f64 ) * SCALE ) as u64 ;
90
107
Bernoulli { p_int }
91
108
}
@@ -95,11 +112,9 @@ impl Distribution<bool> for Bernoulli {
95
112
#[ inline]
96
113
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> bool {
97
114
// Make sure to always return true for p = 1.0.
98
- if self . p_int == :: core:: u64:: MAX {
99
- return true ;
100
- }
101
- let r: u64 = rng. gen ( ) ;
102
- r < self . p_int
115
+ if self . p_int == ALWAYS_TRUE { return true ; }
116
+ let v: u64 = rng. gen ( ) ;
117
+ v < self . p_int
103
118
}
104
119
}
105
120
0 commit comments