Skip to content

Commit 8558b22

Browse files
authored
Merge pull request #96 from fizyk20/discrete
Add binomial and Poisson distributions
2 parents b146ee6 + 38ee0f8 commit 8558b22

File tree

4 files changed

+360
-0
lines changed

4 files changed

+360
-0
lines changed

src/distributions/binomial.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! The binomial distribution.
12+
13+
use Rng;
14+
use distributions::Distribution;
15+
use distributions::log_gamma::log_gamma;
16+
use std::f64::consts::PI;
17+
18+
/// The binomial distribution `Binomial(n, p)`.
19+
///
20+
/// This distribution has density function: `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`.
21+
///
22+
/// # Example
23+
///
24+
/// ```rust
25+
/// use rand::distributions::{Binomial, Distribution};
26+
///
27+
/// let bin = Binomial::new(20, 0.3);
28+
/// let v = bin.sample(&mut rand::thread_rng());
29+
/// println!("{} is from a binomial distribution", v);
30+
/// ```
31+
#[derive(Clone, Copy, Debug)]
32+
pub struct Binomial {
33+
n: u64, // number of trials
34+
p: f64, // probability of success
35+
}
36+
37+
impl Binomial {
38+
/// Construct a new `Binomial` with the given shape parameters
39+
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
40+
pub fn new(n: u64, p: f64) -> Binomial {
41+
assert!(p > 0.0, "Binomial::new called with p <= 0");
42+
assert!(p < 1.0, "Binomial::new called with p >= 1");
43+
Binomial { n: n, p: p }
44+
}
45+
}
46+
47+
impl Distribution<u64> for Binomial {
48+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
49+
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
50+
// switch p so that it is less than 0.5 - this allows for lower expected values
51+
// we will just invert the result at the end
52+
let p = if self.p <= 0.5 {
53+
self.p
54+
} else {
55+
1.0 - self.p
56+
};
57+
58+
// expected value of the sample
59+
let expected = self.n as f64 * p;
60+
61+
let result =
62+
// for low expected values we just simulate n drawings
63+
if expected < 25.0 {
64+
let mut lresult = 0.0;
65+
for _ in 0 .. self.n {
66+
if rng.gen::<f64>() < p {
67+
lresult += 1.0;
68+
}
69+
}
70+
lresult
71+
}
72+
// high expected value - do the rejection method
73+
else {
74+
// prepare some cached values
75+
let float_n = self.n as f64;
76+
let ln_fact_n = log_gamma(float_n + 1.0);
77+
let pc = 1.0 - p;
78+
let log_p = p.ln();
79+
let log_pc = pc.ln();
80+
let sq = (expected * (2.0 * pc)).sqrt();
81+
82+
let mut lresult;
83+
84+
loop {
85+
let mut comp_dev: f64;
86+
// we use the lorentzian distribution as the comparison distribution
87+
// f(x) ~ 1/(1+x/^2)
88+
loop {
89+
// draw from the lorentzian distribution
90+
comp_dev = (PI*rng.gen::<f64>()).tan();
91+
// shift the peak of the comparison ditribution
92+
lresult = expected + sq * comp_dev;
93+
// repeat the drawing until we are in the range of possible values
94+
if lresult >= 0.0 && lresult < float_n + 1.0 {
95+
break;
96+
}
97+
}
98+
99+
// the result should be discrete
100+
lresult = lresult.floor();
101+
102+
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
103+
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
104+
// this is the binomial probability divided by the comparison probability
105+
// we will generate a uniform random value and if it is larger than this,
106+
// we interpret it as a value falling out of the distribution and repeat
107+
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
108+
109+
if comparison_coeff >= rng.gen() {
110+
break;
111+
}
112+
}
113+
114+
lresult
115+
};
116+
117+
// invert the result for p < 0.5
118+
if p != self.p {
119+
self.n - result as u64
120+
} else {
121+
result as u64
122+
}
123+
}
124+
}
125+
126+
#[cfg(test)]
127+
mod test {
128+
use distributions::Distribution;
129+
use super::Binomial;
130+
131+
#[test]
132+
fn test_binomial() {
133+
let binomial = Binomial::new(150, 0.1);
134+
let mut rng = ::test::rng(123);
135+
let mut sum = 0;
136+
for _ in 0..1000 {
137+
sum += binomial.sample(&mut rng);
138+
}
139+
let avg = (sum as f64) / 1000.0;
140+
println!("Binomial average: {}", avg);
141+
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
142+
}
143+
144+
#[test]
145+
#[should_panic]
146+
#[cfg_attr(target_env = "msvc", ignore)]
147+
fn test_binomial_invalid_lambda_zero() {
148+
Binomial::new(20, 0.0);
149+
}
150+
#[test]
151+
#[should_panic]
152+
#[cfg_attr(target_env = "msvc", ignore)]
153+
fn test_binomial_invalid_lambda_neg() {
154+
Binomial::new(20, -10.0);
155+
}
156+
}

src/distributions/log_gamma.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
12+
/// function) using the Lanczos approximation.
13+
///
14+
/// The approximation expresses the gamma function as:
15+
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
16+
/// `g` is an arbitrary constant; we use the approximation with `g=5`.
17+
///
18+
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
19+
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
20+
///
21+
/// `Ag(z)` is an infinite series with coefficients that can be calculated
22+
/// ahead of time - we use just the first 6 terms, which is good enough
23+
/// for most purposes.
24+
pub fn log_gamma(x: f64) -> f64 {
25+
// precalculated 6 coefficients for the first 6 terms of the series
26+
let coefficients: [f64; 6] = [
27+
76.18009172947146,
28+
-86.50532032941677,
29+
24.01409824083091,
30+
-1.231739572450155,
31+
0.1208650973866179e-2,
32+
-0.5395239384953e-5,
33+
];
34+
35+
// (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
36+
let tmp = x + 5.5;
37+
let log = (x + 0.5) * tmp.ln() - tmp;
38+
39+
// the first few terms of the series for Ag(x)
40+
let mut a = 1.000000000190015;
41+
let mut denom = x;
42+
for j in 0..6 {
43+
denom += 1.0;
44+
a += coefficients[j] / denom;
45+
}
46+
47+
// get everything together
48+
// a is Ag(x)
49+
// 2.5066... is sqrt(2pi)
50+
return log + (2.5066282746310005 * a / x).ln();
51+
}

src/distributions/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
2525
pub use self::normal::{Normal, LogNormal, StandardNormal};
2626
#[cfg(feature="std")]
2727
pub use self::exponential::{Exp, Exp1};
28+
#[cfg(feature = "std")]
29+
pub use self::poisson::Poisson;
30+
#[cfg(feature = "std")]
31+
pub use self::binomial::Binomial;
2832

2933
pub mod range;
3034
#[cfg(feature="std")]
@@ -33,9 +37,14 @@ pub mod gamma;
3337
pub mod normal;
3438
#[cfg(feature="std")]
3539
pub mod exponential;
40+
#[cfg(feature = "std")]
41+
pub mod poisson;
42+
#[cfg(feature = "std")]
43+
pub mod binomial;
3644

3745
mod float;
3846
mod integer;
47+
mod log_gamma;
3948
mod other;
4049
#[cfg(feature="std")]
4150
mod ziggurat_tables;

src/distributions/poisson.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! The Poisson distribution.
12+
13+
use Rng;
14+
use distributions::Distribution;
15+
use distributions::log_gamma::log_gamma;
16+
use std::f64::consts::PI;
17+
18+
/// The Poisson distribution `Poisson(lambda)`.
19+
///
20+
/// This distribution has a density function:
21+
/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`.
22+
///
23+
/// # Example
24+
///
25+
/// ```rust
26+
/// use rand::distributions::{Poisson, Distribution};
27+
///
28+
/// let poi = Poisson::new(2.0);
29+
/// let v = poi.sample(&mut rand::thread_rng());
30+
/// println!("{} is from a Poisson(2) distribution", v);
31+
/// ```
32+
#[derive(Clone, Copy, Debug)]
33+
pub struct Poisson {
34+
lambda: f64,
35+
// precalculated values
36+
exp_lambda: f64,
37+
log_lambda: f64,
38+
sqrt_2lambda: f64,
39+
magic_val: f64,
40+
}
41+
42+
impl Poisson {
43+
/// Construct a new `Poisson` with the given shape parameter
44+
/// `lambda`. Panics if `lambda <= 0`.
45+
pub fn new(lambda: f64) -> Poisson {
46+
assert!(lambda > 0.0, "Poisson::new called with lambda <= 0");
47+
let log_lambda = lambda.ln();
48+
Poisson {
49+
lambda: lambda,
50+
exp_lambda: (-lambda).exp(),
51+
log_lambda: log_lambda,
52+
sqrt_2lambda: (2.0 * lambda).sqrt(),
53+
magic_val: lambda * log_lambda - log_gamma(1.0 + lambda),
54+
}
55+
}
56+
}
57+
58+
impl Distribution<u64> for Poisson {
59+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
60+
// using the algorithm from Numerical Recipes in C
61+
62+
// for low expected values use the Knuth method
63+
if self.lambda < 12.0 {
64+
let mut result = 0;
65+
let mut p = 1.0;
66+
while p > self.exp_lambda {
67+
p *= rng.gen::<f64>();
68+
result += 1;
69+
}
70+
result - 1
71+
}
72+
// high expected values - rejection method
73+
else {
74+
let mut int_result: u64;
75+
76+
loop {
77+
let mut result;
78+
let mut comp_dev;
79+
80+
// we use the lorentzian distribution as the comparison distribution
81+
// f(x) ~ 1/(1+x/^2)
82+
loop {
83+
// draw from the lorentzian distribution
84+
comp_dev = (PI * rng.gen::<f64>()).tan();
85+
// shift the peak of the comparison ditribution
86+
result = self.sqrt_2lambda * comp_dev + self.lambda;
87+
// repeat the drawing until we are in the range of possible values
88+
if result >= 0.0 {
89+
break;
90+
}
91+
}
92+
// now the result is a random variable greater than 0 with Lorentzian distribution
93+
// the result should be an integer value
94+
result = result.floor();
95+
int_result = result as u64;
96+
97+
// this is the ratio of the Poisson distribution to the comparison distribution
98+
// the magic value scales the distribution function to a range of approximately 0-1
99+
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
100+
// this doesn't change the resulting distribution, only increases the rate of failed drawings
101+
let check = 0.9 * (1.0 + comp_dev * comp_dev)
102+
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp();
103+
104+
// check with uniform random value - if below the threshold, we are within the target distribution
105+
if rng.gen::<f64>() <= check {
106+
break;
107+
}
108+
}
109+
int_result
110+
}
111+
}
112+
}
113+
114+
#[cfg(test)]
115+
mod test {
116+
use distributions::Distribution;
117+
use super::Poisson;
118+
119+
#[test]
120+
fn test_poisson() {
121+
let poisson = Poisson::new(10.0);
122+
let mut rng = ::test::rng(123);
123+
let mut sum = 0;
124+
for _ in 0..1000 {
125+
sum += poisson.sample(&mut rng);
126+
}
127+
let avg = (sum as f64) / 1000.0;
128+
println!("Poisson average: {}", avg);
129+
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
130+
}
131+
132+
#[test]
133+
#[should_panic]
134+
#[cfg_attr(target_env = "msvc", ignore)]
135+
fn test_poisson_invalid_lambda_zero() {
136+
Poisson::new(0.0);
137+
}
138+
#[test]
139+
#[should_panic]
140+
#[cfg_attr(target_env = "msvc", ignore)]
141+
fn test_poisson_invalid_lambda_neg() {
142+
Poisson::new(-10.0);
143+
}
144+
}

0 commit comments

Comments
 (0)