Skip to content

Commit e47c5a9

Browse files
authored
Merge pull request #740 from vks/faster-binomial2
Binomial: Faster sampling for n * p >= 10
2 parents 1eef88c + f8149ab commit e47c5a9

File tree

1 file changed

+174
-45
lines changed

1 file changed

+174
-45
lines changed

src/distributions/binomial.rs

+174-45
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
//! The binomial distribution.
1111
1212
use Rng;
13-
use distributions::{Distribution, Cauchy};
14-
use distributions::utils::log_gamma;
13+
use distributions::{Distribution, Uniform};
1514

1615
/// The binomial distribution `Binomial(n, p)`.
1716
///
@@ -47,6 +46,13 @@ impl Binomial {
4746
}
4847
}
4948

49+
/// Convert a `f64` to an `i64`, panicing on overflow.
50+
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
51+
fn f64_to_i64(x: f64) -> i64 {
52+
assert!(x < (::std::i64::MAX as f64));
53+
x as i64
54+
}
55+
5056
impl Distribution<u64> for Binomial {
5157
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
5258
// Handle these values directly.
@@ -56,25 +62,33 @@ impl Distribution<u64> for Binomial {
5662
return self.n;
5763
}
5864

59-
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
60-
// switch p so that it is less than 0.5 - this allows for lower expected values
61-
// we will just invert the result at the end
65+
// The binomial distribution is symmetrical with respect to p -> 1-p,
66+
// k -> n-k switch p so that it is less than 0.5 - this allows for lower
67+
// expected values we will just invert the result at the end
6268
let p = if self.p <= 0.5 {
6369
self.p
6470
} else {
6571
1.0 - self.p
6672
};
6773

6874
let result;
75+
let q = 1. - p;
6976

7077
// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
71-
// transformation of the binomial distribution is more efficient:
78+
// transformation of the binomial distribution is efficient. Otherwise,
79+
// the BTPE algorithm is used.
7280
//
7381
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
7482
// random variate generation. Commun. ACM 31, 2 (February 1988),
7583
// 216-222. http://dx.doi.org/10.1145/42372.42381
76-
if (self.n as f64) * p < 10. && self.n <= (::std::i32::MAX as u64) {
77-
let q = 1. - p;
84+
85+
// Threshold for prefering the BINV algorithm. The paper suggests 10,
86+
// Ranlib uses 30, and GSL uses 14.
87+
const BINV_THRESHOLD: f64 = 10.;
88+
89+
if (self.n as f64) * p < BINV_THRESHOLD &&
90+
self.n <= (::std::i32::MAX as u64) {
91+
// Use the BINV algorithm.
7892
let s = p / q;
7993
let a = ((self.n + 1) as f64) * s;
8094
let mut r = q.powi(self.n as i32);
@@ -87,52 +101,165 @@ impl Distribution<u64> for Binomial {
87101
}
88102
result = x;
89103
} else {
90-
// FIXME: Using the BTPE algorithm is probably faster.
91-
92-
// prepare some cached values
93-
let float_n = self.n as f64;
94-
let ln_fact_n = log_gamma(float_n + 1.0);
95-
let pc = 1.0 - p;
96-
let log_p = p.ln();
97-
let log_pc = pc.ln();
98-
let expected = self.n as f64 * p;
99-
let sq = (expected * (2.0 * pc)).sqrt();
100-
let mut lresult;
101-
102-
// we use the Cauchy distribution as the comparison distribution
103-
// f(x) ~ 1/(1+x^2)
104-
let cauchy = Cauchy::new(0.0, 1.0);
104+
// Use the BTPE algorithm.
105+
106+
// Threshold for using the squeeze algorithm. This can be freely
107+
// chosen based on performance. Ranlib and GSL use 20.
108+
const SQUEEZE_THRESHOLD: i64 = 20;
109+
110+
// Step 0: Calculate constants as functions of `n` and `p`.
111+
let n = self.n as f64;
112+
let np = n * p;
113+
let npq = np * q;
114+
let f_m = np + p;
115+
let m = f64_to_i64(f_m);
116+
// radius of triangle region, since height=1 also area of region
117+
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
118+
// tip of triangle
119+
let x_m = (m as f64) + 0.5;
120+
// left edge of triangle
121+
let x_l = x_m - p1;
122+
// right edge of triangle
123+
let x_r = x_m + p1;
124+
let c = 0.134 + 20.5 / (15.3 + (m as f64));
125+
// p1 + area of parallelogram region
126+
let p2 = p1 * (1. + 2. * c);
127+
128+
fn lambda(a: f64) -> f64 {
129+
a * (1. + 0.5 * a)
130+
}
131+
132+
let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
133+
let lambda_r = lambda((x_r - f_m) / (x_r * q));
134+
// p1 + area of left tail
135+
let p3 = p2 + c / lambda_l;
136+
// p1 + area of right tail
137+
let p4 = p3 + c / lambda_r;
138+
139+
// return value
140+
let mut y: i64;
141+
142+
let gen_u = Uniform::new(0., p4);
143+
let gen_v = Uniform::new(0., 1.);
144+
105145
loop {
106-
let mut comp_dev: f64;
107-
loop {
108-
// draw from the Cauchy distribution
109-
comp_dev = rng.sample(cauchy);
110-
// shift the peak of the comparison ditribution
111-
lresult = expected + sq * comp_dev;
112-
// repeat the drawing until we are in the range of possible values
113-
if lresult >= 0.0 && lresult < float_n + 1.0 {
114-
break;
146+
// Step 1: Generate `u` for selecting the region. If region 1 is
147+
// selected, generate a triangularly distributed variate.
148+
let u = gen_u.sample(rng);
149+
let mut v = gen_v.sample(rng);
150+
if !(u > p1) {
151+
y = f64_to_i64(x_m - p1 * v + u);
152+
break;
153+
}
154+
155+
if !(u > p2) {
156+
// Step 2: Region 2, parallelograms. Check if region 2 is
157+
// used. If so, generate `y`.
158+
let x = x_l + (u - p1) / c;
159+
v = v * c + 1.0 - (x - x_m).abs() / p1;
160+
if v > 1. {
161+
continue;
162+
} else {
163+
y = f64_to_i64(x);
164+
}
165+
} else if !(u > p3) {
166+
// Step 3: Region 3, left exponential tail.
167+
y = f64_to_i64(x_l + v.ln() / lambda_l);
168+
if y < 0 {
169+
continue;
170+
} else {
171+
v *= (u - p2) * lambda_l;
172+
}
173+
} else {
174+
// Step 4: Region 4, right exponential tail.
175+
y = f64_to_i64(x_r - v.ln() / lambda_r);
176+
if y > 0 && (y as u64) > self.n {
177+
continue;
178+
} else {
179+
v *= (u - p3) * lambda_r;
115180
}
116181
}
117182

118-
// the result should be discrete
119-
lresult = lresult.floor();
183+
// Step 5: Acceptance/rejection comparison.
120184

121-
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
122-
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
123-
// this is the binomial probability divided by the comparison probability
124-
// we will generate a uniform random value and if it is larger than this,
125-
// we interpret it as a value falling out of the distribution and repeat
126-
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
185+
// Step 5.0: Test for appropriate method of evaluating f(y).
186+
let k = (y - m).abs();
187+
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
188+
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
189+
// search from the mode.
190+
let s = p / q;
191+
let a = s * (n + 1.);
192+
let mut f = 1.0;
193+
if m < y {
194+
let mut i = m;
195+
loop {
196+
i += 1;
197+
f *= a / (i as f64) - s;
198+
if i == y {
199+
break;
200+
}
201+
}
202+
} else if m > y {
203+
let mut i = y;
204+
loop {
205+
i += 1;
206+
f /= a / (i as f64) - s;
207+
if i == m {
208+
break;
209+
}
210+
}
211+
}
212+
if v > f {
213+
continue;
214+
} else {
215+
break;
216+
}
217+
}
127218

128-
if comparison_coeff >= rng.gen() {
219+
// Step 5.2: Squeezing. Check the value of ln(v) againts upper and
220+
// lower bound of ln(f(y)).
221+
let k = k as f64;
222+
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5);
223+
let t = -0.5 * k*k / npq;
224+
let alpha = v.ln();
225+
if alpha < t - rho {
129226
break;
130227
}
228+
if alpha > t + rho {
229+
continue;
230+
}
231+
232+
// Step 5.3: Final acceptance/rejection test.
233+
let x1 = (y + 1) as f64;
234+
let f1 = (m + 1) as f64;
235+
let z = (f64_to_i64(n) + 1 - m) as f64;
236+
let w = (f64_to_i64(n) - y + 1) as f64;
237+
238+
fn stirling(a: f64) -> f64 {
239+
let a2 = a * a;
240+
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
241+
}
242+
243+
if alpha > x_m * (f1 / x1).ln()
244+
+ (n - (m as f64) + 0.5) * (z / w).ln()
245+
+ ((y - m) as f64) * (w * p / (x1 * q)).ln()
246+
// We use the signs from the GSL implementation, which are
247+
// different than the ones in the reference. According to
248+
// the GSL authors, the new signs were verified to be
249+
// correct by one of the original designers of the
250+
// algorithm.
251+
+ stirling(f1) + stirling(z) - stirling(x1) - stirling(w)
252+
{
253+
continue;
254+
}
255+
256+
break;
131257
}
132-
result = lresult as u64;
258+
assert!(y >= 0);
259+
result = y as u64;
133260
}
134261

135-
// invert the result for p < 0.5
262+
// Invert the result for p < 0.5.
136263
if p != self.p {
137264
self.n - result
138265
} else {
@@ -157,12 +284,14 @@ mod test {
157284
for i in results.iter_mut() { *i = binomial.sample(rng) as f64; }
158285

159286
let mean = results.iter().sum::<f64>() / results.len() as f64;
160-
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0);
287+
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0,
288+
"mean: {}, expected_mean: {}", mean, expected_mean);
161289

162290
let variance =
163291
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>()
164292
/ results.len() as f64;
165-
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
293+
assert!((variance - expected_variance).abs() < expected_variance / 10.0,
294+
"variance: {}, expected_variance: {}", variance, expected_variance);
166295
}
167296

168297
#[test]

0 commit comments

Comments
 (0)