|  | 
|  | 1 | +// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | 
|  | 2 | +// file at the top-level directory of this distribution and at | 
|  | 3 | +// https://rust-lang.org/COPYRIGHT. | 
|  | 4 | +// | 
|  | 5 | +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | 
|  | 6 | +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | 
|  | 7 | +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | 
|  | 8 | +// option. This file may not be copied, modified, or distributed | 
|  | 9 | +// except according to those terms. | 
|  | 10 | + | 
|  | 11 | +//! The binomial distribution. | 
|  | 12 | +
 | 
|  | 13 | +use Rng; | 
|  | 14 | +use distributions::Distribution; | 
|  | 15 | +use distributions::log_gamma::log_gamma; | 
|  | 16 | +use std::f64::consts::PI; | 
|  | 17 | + | 
|  | 18 | +/// The binomial distribution `Binomial(n, p)`. | 
|  | 19 | +/// | 
|  | 20 | +/// This distribution has density function: `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. | 
|  | 21 | +/// | 
|  | 22 | +/// # Example | 
|  | 23 | +/// | 
|  | 24 | +/// ```rust | 
|  | 25 | +/// use rand::distributions::{Binomial, Distribution}; | 
|  | 26 | +/// | 
|  | 27 | +/// let bin = Binomial::new(20, 0.3); | 
|  | 28 | +/// let v = bin.sample(&mut rand::thread_rng()); | 
|  | 29 | +/// println!("{} is from a binomial distribution", v); | 
|  | 30 | +/// ``` | 
|  | 31 | +#[derive(Clone, Copy, Debug)] | 
|  | 32 | +pub struct Binomial { | 
|  | 33 | +    n: u64, // number of trials | 
|  | 34 | +    p: f64, // probability of success | 
|  | 35 | +} | 
|  | 36 | + | 
|  | 37 | +impl Binomial { | 
|  | 38 | +    /// Construct a new `Binomial` with the given shape parameters | 
|  | 39 | +    /// `n`, `p`. Panics if `p <= 0` or `p >= 1`. | 
|  | 40 | +    pub fn new(n: u64, p: f64) -> Binomial { | 
|  | 41 | +        assert!(p > 0.0, "Binomial::new called with `p` <= 0"); | 
|  | 42 | +        assert!(p < 1.0, "Binomial::new called with `p` >= 1"); | 
|  | 43 | +        Binomial { n: n, p: p } | 
|  | 44 | +    } | 
|  | 45 | +} | 
|  | 46 | + | 
|  | 47 | +impl Distribution<u64> for Binomial { | 
|  | 48 | +    fn sample<R: Rng>(&self, rng: &mut R) -> u64 { | 
|  | 49 | +        // binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k | 
|  | 50 | +        // switch p so that it is less than 0.5 - this allows for lower expected values | 
|  | 51 | +        // we will just invert the result at the end | 
|  | 52 | +        let p = if self.p <= 0.5 { | 
|  | 53 | +            self.p | 
|  | 54 | +        } else { | 
|  | 55 | +            1.0 - self.p | 
|  | 56 | +        }; | 
|  | 57 | + | 
|  | 58 | +        // expected value of the sample | 
|  | 59 | +        let expected = self.n as f64 * p; | 
|  | 60 | + | 
|  | 61 | +        let result = | 
|  | 62 | +            // for low expected values we just simulate n drawings | 
|  | 63 | +            if expected < 25.0 { | 
|  | 64 | +                let mut lresult = 0.0; | 
|  | 65 | +                for _ in 0 .. self.n { | 
|  | 66 | +                    if rng.gen::<f64>() < p { | 
|  | 67 | +                        lresult += 1.0; | 
|  | 68 | +                    } | 
|  | 69 | +                } | 
|  | 70 | +                lresult | 
|  | 71 | +            } | 
|  | 72 | +            // high expected value - do the rejection method | 
|  | 73 | +            else { | 
|  | 74 | +                // prepare some cached values | 
|  | 75 | +                let float_n = self.n as f64; | 
|  | 76 | +                let ln_fact_n = log_gamma(float_n + 1.0); | 
|  | 77 | +                let pc = 1.0 - p; | 
|  | 78 | +                let log_p = p.ln(); | 
|  | 79 | +                let log_pc = pc.ln(); | 
|  | 80 | +                let sq = (expected * (2.0 * pc)).sqrt(); | 
|  | 81 | + | 
|  | 82 | +                let mut lresult; | 
|  | 83 | + | 
|  | 84 | +                loop { | 
|  | 85 | +                    let mut comp_dev: f64; | 
|  | 86 | +                    // we use the lorentzian distribution as the comparison distribution | 
|  | 87 | +                    // f(x) ~ 1/(1+x/^2) | 
|  | 88 | +                    loop { | 
|  | 89 | +                        // draw from the lorentzian distribution | 
|  | 90 | +                        comp_dev = (PI*rng.gen::<f64>()).tan(); | 
|  | 91 | +                        // shift the peak of the comparison ditribution | 
|  | 92 | +                        lresult = expected + sq * comp_dev; | 
|  | 93 | +                        // repeat the drawing until we are in the range of possible values | 
|  | 94 | +                        if lresult >= 0.0 && lresult < float_n + 1.0 { | 
|  | 95 | +                            break; | 
|  | 96 | +                        } | 
|  | 97 | +                    } | 
|  | 98 | + | 
|  | 99 | +                    // the result should be discrete | 
|  | 100 | +                    lresult = lresult.floor(); | 
|  | 101 | + | 
|  | 102 | +                    let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) - | 
|  | 103 | +                        log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc; | 
|  | 104 | +                    // this is the binomial probability divided by the comparison probability | 
|  | 105 | +                    // we will generate a uniform random value and if it is larger than this, | 
|  | 106 | +                    // we interpret it as a value falling out of the distribution and repeat | 
|  | 107 | +                    let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev)); | 
|  | 108 | + | 
|  | 109 | +                    if comparison_coeff >= rng.gen() { | 
|  | 110 | +                        break; | 
|  | 111 | +                    } | 
|  | 112 | +                } | 
|  | 113 | + | 
|  | 114 | +                lresult | 
|  | 115 | +            }; | 
|  | 116 | + | 
|  | 117 | +        // invert the result for p < 0.5 | 
|  | 118 | +        if p != self.p { | 
|  | 119 | +            self.n - result as u64 | 
|  | 120 | +        } else { | 
|  | 121 | +            result as u64 | 
|  | 122 | +        } | 
|  | 123 | +    } | 
|  | 124 | +} | 
|  | 125 | + | 
|  | 126 | +#[cfg(test)] | 
|  | 127 | +mod test { | 
|  | 128 | +    use distributions::Distribution; | 
|  | 129 | +    use super::Binomial; | 
|  | 130 | + | 
|  | 131 | +    #[test] | 
|  | 132 | +    fn test_binomial() { | 
|  | 133 | +        let binomial = Binomial::new(150, 0.1); | 
|  | 134 | +        let mut rng = ::test::rng(123); | 
|  | 135 | +        let mut sum = 0; | 
|  | 136 | +        for _ in 0..1000 { | 
|  | 137 | +            sum += binomial.sample(&mut rng); | 
|  | 138 | +        } | 
|  | 139 | +        let avg = (sum as f64) / 1000.0; | 
|  | 140 | +        println!("Binomial average: {}", avg); | 
|  | 141 | +        assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | 
|  | 142 | +    } | 
|  | 143 | + | 
|  | 144 | +    #[test] | 
|  | 145 | +    #[should_panic] | 
|  | 146 | +    #[cfg_attr(target_env = "msvc", ignore)] | 
|  | 147 | +    fn test_binomial_invalid_lambda_zero() { | 
|  | 148 | +        Binomial::new(20, 0.0); | 
|  | 149 | +    } | 
|  | 150 | +    #[test] | 
|  | 151 | +    #[should_panic] | 
|  | 152 | +    #[cfg_attr(target_env = "msvc", ignore)] | 
|  | 153 | +    fn test_binomial_invalid_lambda_neg() { | 
|  | 154 | +        Binomial::new(20, -10.0); | 
|  | 155 | +    } | 
|  | 156 | +} | 
0 commit comments