Skip to content

Commit 2ced48c

Browse files
committed
Add Poisson and binomial distributions
1 parent 8245d5f commit 2ced48c

12 files changed

+1681
-439
lines changed

src/distributions/binomial.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! The binomial distribution.
12+
13+
use Rng;
14+
use distributions::Distribution;
15+
use distributions::log_gamma::log_gamma;
16+
use std::f64::consts::PI;
17+
18+
/// The binomial distribution `Binomial(n, p)`.
19+
///
20+
/// This distribution has probability mass function: `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
21+
///
22+
/// # Example
23+
///
24+
/// ```rust
25+
/// use rand::distributions::{Binomial, Distribution};
26+
///
27+
/// let bin = Binomial::new(20, 0.3);
28+
/// let v = bin.sample(&mut rand::thread_rng());
29+
/// println!("{} is from a binomial distribution", v);
30+
/// ```
31+
#[derive(Clone, Copy, Debug)]
32+
pub struct Binomial {
33+
n: u64, // number of trials
34+
p: f64, // probability of success
35+
}
36+
37+
impl Binomial {
38+
/// Construct a new `Binomial` with the given shape parameters
39+
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
40+
pub fn new(n: u64, p: f64) -> Binomial {
41+
assert!(p > 0.0, "Binomial::new called with `p` <= 0");
42+
assert!(p < 1.0, "Binomial::new called with `p` >= 1");
43+
Binomial { n: n, p: p }
44+
}
45+
}
46+
47+
impl Distribution<u64> for Binomial {
48+
fn sample<R: Rng>(&self, rng: &mut R) -> u64 {
49+
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
50+
// switch p so that it is less than 0.5 - this allows for lower expected values
51+
// we will just invert the result at the end
52+
let p = if self.p <= 0.5 {
53+
self.p
54+
} else {
55+
1.0 - self.p
56+
};
57+
58+
// expected value of the sample
59+
let expected = self.n as f64 * p;
60+
61+
let result =
62+
// for low expected values we just simulate n drawings
63+
if expected < 25.0 {
64+
let mut lresult = 0.0;
65+
for _ in 0 .. self.n {
66+
if rng.gen::<f64>() < p {
67+
lresult += 1.0;
68+
}
69+
}
70+
lresult
71+
}
72+
// high expected value - do the rejection method
73+
else {
74+
// prepare some cached values
75+
let float_n = self.n as f64;
76+
let ln_fact_n = log_gamma(float_n + 1.0);
77+
let pc = 1.0 - p;
78+
let log_p = p.ln();
79+
let log_pc = pc.ln();
80+
let sq = (expected * (2.0 * pc)).sqrt();
81+
82+
let mut lresult;
83+
84+
loop {
85+
let mut comp_dev: f64;
86+
// we use the lorentzian distribution as the comparison distribution
87+
// f(x) ~ 1/(1+x^2)
88+
loop {
89+
// draw from the lorentzian distribution
90+
comp_dev = (PI*rng.gen::<f64>()).tan();
91+
// shift the peak of the comparison distribution
92+
lresult = expected + sq * comp_dev;
93+
// repeat the drawing until we are in the range of possible values
94+
if lresult >= 0.0 && lresult < float_n + 1.0 {
95+
break;
96+
}
97+
}
98+
99+
// the result should be discrete
100+
lresult = lresult.floor();
101+
102+
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
103+
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
104+
// this is the binomial probability divided by the comparison probability
105+
// we will generate a uniform random value and if it is larger than this,
106+
// we interpret it as a value falling out of the distribution and repeat
107+
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
108+
109+
if comparison_coeff >= rng.gen() {
110+
break;
111+
}
112+
}
113+
114+
lresult
115+
};
116+
117+
// invert the result if p was swapped for 1-p above (i.e. self.p > 0.5)
118+
if p != self.p {
119+
self.n - result as u64
120+
} else {
121+
result as u64
122+
}
123+
}
124+
}
125+
126+
#[cfg(test)]
127+
mod test {
128+
use distributions::Distribution;
129+
use super::Binomial;
130+
131+
#[test]
132+
fn test_binomial() {
133+
let binomial = Binomial::new(150, 0.1);
134+
let mut rng = ::test::rng(123);
135+
let mut sum = 0;
136+
for _ in 0..1000 {
137+
sum += binomial.sample(&mut rng);
138+
}
139+
let avg = (sum as f64) / 1000.0;
140+
println!("Binomial average: {}", avg);
141+
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
142+
}
143+
144+
#[test]
145+
#[should_panic]
146+
#[cfg_attr(target_env = "msvc", ignore)]
147+
fn test_binomial_invalid_lambda_zero() {
148+
Binomial::new(20, 0.0);
149+
}
150+
#[test]
151+
#[should_panic]
152+
#[cfg_attr(target_env = "msvc", ignore)]
153+
fn test_binomial_invalid_lambda_neg() {
154+
Binomial::new(20, -10.0);
155+
}
156+
}

src/distributions/exponential.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
//! The exponential distribution.
1212
13-
use {Rng};
13+
use Rng;
1414
use distributions::{ziggurat, ziggurat_tables, Distribution};
1515

1616
/// Samples floating-point numbers according to the exponential distribution,
@@ -52,10 +52,14 @@ impl Distribution<f64> for Exp1 {
5252
ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln()
5353
}
5454

55-
ziggurat(rng, false,
56-
&ziggurat_tables::ZIG_EXP_X,
57-
&ziggurat_tables::ZIG_EXP_F,
58-
pdf, zero_case)
55+
ziggurat(
56+
rng,
57+
false,
58+
&ziggurat_tables::ZIG_EXP_X,
59+
&ziggurat_tables::ZIG_EXP_F,
60+
pdf,
61+
zero_case,
62+
)
5963
}
6064
}
6165

@@ -76,7 +80,7 @@ impl Distribution<f64> for Exp1 {
7680
#[derive(Clone, Copy, Debug)]
7781
pub struct Exp {
7882
/// `lambda` stored as `1/lambda`, since this is what we scale by.
79-
lambda_inverse: f64
83+
lambda_inverse: f64,
8084
}
8185

8286
impl Exp {
@@ -85,7 +89,9 @@ impl Exp {
8589
#[inline]
8690
pub fn new(lambda: f64) -> Exp {
8791
assert!(lambda > 0.0, "Exp::new called with `lambda` <= 0");
88-
Exp { lambda_inverse: 1.0 / lambda }
92+
Exp {
93+
lambda_inverse: 1.0 / lambda,
94+
}
8995
}
9096
}
9197

src/distributions/float.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ macro_rules! float_impls {
6969
float_impls! { f32, u32, 23, 127, next_u32 }
7070
float_impls! { f64, u64, 52, 1023, next_u64 }
7171

72-
7372
#[cfg(test)]
7473
mod tests {
7574
use Rng;

src/distributions/gamma.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
use self::GammaRepr::*;
1616
use self::ChiSquaredRepr::*;
1717

18-
use {Rng};
18+
use Rng;
1919
use distributions::normal::StandardNormal;
2020
use distributions::{Distribution, Exp, Uniform};
2121

@@ -58,7 +58,7 @@ pub struct Gamma {
5858
enum GammaRepr {
5959
Large(GammaLargeShape),
6060
One(Exp),
61-
Small(GammaSmallShape)
61+
Small(GammaSmallShape),
6262
}
6363

6464
// These two helpers could be made public, but saving the
@@ -78,7 +78,7 @@ enum GammaRepr {
7878
#[derive(Clone, Copy, Debug)]
7979
struct GammaSmallShape {
8080
inv_shape: f64,
81-
large_shape: GammaLargeShape
81+
large_shape: GammaLargeShape,
8282
}
8383

8484
/// Gamma distribution where the shape parameter is larger than 1.
@@ -89,7 +89,7 @@ struct GammaSmallShape {
8989
struct GammaLargeShape {
9090
scale: f64,
9191
c: f64,
92-
d: f64
92+
d: f64,
9393
}
9494

9595
impl Gamma {
@@ -117,7 +117,7 @@ impl GammaSmallShape {
117117
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
118118
GammaSmallShape {
119119
inv_shape: 1. / shape,
120-
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
120+
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale),
121121
}
122122
}
123123
}
@@ -128,7 +128,7 @@ impl GammaLargeShape {
128128
GammaLargeShape {
129129
scale: scale,
130130
c: 1. / (9. * d).sqrt(),
131-
d: d
131+
d: d,
132132
}
133133
}
134134
}
@@ -154,17 +154,19 @@ impl Distribution<f64> for GammaLargeShape {
154154
loop {
155155
let x = rng.sample(StandardNormal);
156156
let v_cbrt = 1.0 + self.c * x;
157-
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
158-
continue
157+
if v_cbrt <= 0.0 {
158+
// a^3 <= 0 iff a <= 0
159+
continue;
159160
}
160161

161162
let v = v_cbrt * v_cbrt * v_cbrt;
162163
let u: f64 = rng.sample(Uniform);
163164

164165
let x_sqr = x * x;
165-
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
166-
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
167-
return self.d * v * self.scale
166+
if u < 1.0 - 0.0331 * x_sqr * x_sqr
167+
|| u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln())
168+
{
169+
return self.d * v * self.scale;
168170
}
169171
}
170172
}
@@ -222,7 +224,7 @@ impl Distribution<f64> for ChiSquared {
222224
let norm = rng.sample(StandardNormal);
223225
norm * norm
224226
}
225-
DoFAnythingElse(ref g) => g.sample(rng)
227+
DoFAnythingElse(ref g) => g.sample(rng),
226228
}
227229
}
228230
}
@@ -261,7 +263,7 @@ impl FisherF {
261263
FisherF {
262264
numer: ChiSquared::new(m),
263265
denom: ChiSquared::new(n),
264-
dof_ratio: n / m
266+
dof_ratio: n / m,
265267
}
266268
}
267269
}
@@ -286,7 +288,7 @@ impl Distribution<f64> for FisherF {
286288
#[derive(Clone, Copy, Debug)]
287289
pub struct StudentT {
288290
chi: ChiSquared,
289-
dof: f64
291+
dof: f64,
290292
}
291293

292294
impl StudentT {
@@ -296,7 +298,7 @@ impl StudentT {
296298
assert!(n > 0.0, "StudentT::new called with `n <= 0`");
297299
StudentT {
298300
chi: ChiSquared::new(n),
299-
dof: n
301+
dof: n,
300302
}
301303
}
302304
}
@@ -310,7 +312,7 @@ impl Distribution<f64> for StudentT {
310312
#[cfg(test)]
311313
mod test {
312314
use distributions::Distribution;
313-
use super::{ChiSquared, StudentT, FisherF};
315+
use super::{ChiSquared, FisherF, StudentT};
314316

315317
#[test]
316318
fn test_chi_squared_one() {

src/distributions/integer.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
1313
use core::mem;
1414

15-
use {Rng};
15+
use Rng;
1616
use distributions::{Distribution, Uniform};
1717

1818
impl Distribution<isize> for Uniform {
@@ -109,24 +109,23 @@ impl Distribution<u128> for Uniform {
109109
}
110110
}
111111

112-
113112
#[cfg(test)]
114113
mod tests {
115114
use Rng;
116-
use distributions::{Uniform};
117-
115+
use distributions::Uniform;
116+
118117
#[test]
119118
fn test_integers() {
120119
let mut rng = ::test::rng(806);
121-
120+
122121
rng.sample::<isize, _>(Uniform);
123122
rng.sample::<i8, _>(Uniform);
124123
rng.sample::<i16, _>(Uniform);
125124
rng.sample::<i32, _>(Uniform);
126125
rng.sample::<i64, _>(Uniform);
127126
#[cfg(feature = "i128_support")]
128127
rng.sample::<i128, _>(Uniform);
129-
128+
130129
rng.sample::<usize, _>(Uniform);
131130
rng.sample::<u8, _>(Uniform);
132131
rng.sample::<u16, _>(Uniform);

0 commit comments

Comments
 (0)