Skip to content

Commit 2b6c326

Browse files
authored
Merge pull request #256 from dhardy/distribution
Replace distribution::Sample with Distribution + polymorphism over Rng
2 parents 9c68f34 + b1ea6ef commit 2b6c326

File tree

14 files changed

+788
-687
lines changed

14 files changed

+788
-687
lines changed

benches/distributions/exponential.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ use std::mem::size_of;
22
use test::Bencher;
33
use rand;
44
use rand::distributions::exponential::Exp;
5-
use rand::distributions::Sample;
5+
use rand::distributions::Distribution;
66

77
#[bench]
88
fn rand_exp(b: &mut Bencher) {
99
let mut rng = rand::weak_rng();
10-
let mut exp = Exp::new(2.71828 * 3.14159);
10+
let exp = Exp::new(2.71828 * 3.14159);
1111

1212
b.iter(|| {
1313
for _ in 0..::RAND_BENCH_N {

benches/distributions/gamma.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::mem::size_of;
22
use test::Bencher;
33
use rand;
4-
use rand::distributions::IndependentSample;
4+
use rand::distributions::Distribution;
55
use rand::distributions::gamma::Gamma;
66

77
#[bench]
@@ -11,7 +11,7 @@ fn bench_gamma_large_shape(b: &mut Bencher) {
1111

1212
b.iter(|| {
1313
for _ in 0..::RAND_BENCH_N {
14-
gamma.ind_sample(&mut rng);
14+
gamma.sample(&mut rng);
1515
}
1616
});
1717
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
@@ -24,7 +24,7 @@ fn bench_gamma_small_shape(b: &mut Bencher) {
2424

2525
b.iter(|| {
2626
for _ in 0..::RAND_BENCH_N {
27-
gamma.ind_sample(&mut rng);
27+
gamma.sample(&mut rng);
2828
}
2929
});
3030
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;

benches/distributions/normal.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use std::mem::size_of;
22
use test::Bencher;
33
use rand;
4-
use rand::distributions::Sample;
4+
use rand::distributions::Distribution;
55
use rand::distributions::normal::Normal;
66

77
#[bench]
88
fn rand_normal(b: &mut Bencher) {
99
let mut rng = rand::weak_rng();
10-
let mut normal = Normal::new(-2.71828, 3.14159);
10+
let normal = Normal::new(-2.71828, 3.14159);
1111

1212
b.iter(|| {
1313
for _ in 0..::RAND_BENCH_N {

benches/generators.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const BYTES_LEN: usize = 1024;
99
use std::mem::size_of;
1010
use test::{black_box, Bencher};
1111

12-
use rand::{RngCore, Rng, NewRng, StdRng, OsRng, JitterRng, EntropyRng};
12+
use rand::{RngCore, Rng, SeedableRng, NewRng, StdRng, OsRng, JitterRng, EntropyRng};
1313
use rand::{XorShiftRng, Hc128Rng, IsaacRng, Isaac64Rng, ChaChaRng};
1414
use rand::reseeding::ReseedingRng;
1515

@@ -41,7 +41,7 @@ macro_rules! gen_uint {
4141
($fnn:ident, $ty:ty, $gen:ident) => {
4242
#[bench]
4343
fn $fnn(b: &mut Bencher) {
44-
let mut rng: $gen = OsRng::new().unwrap().gen();
44+
let mut rng = $gen::new().unwrap();
4545
b.iter(|| {
4646
for _ in 0..RAND_BENCH_N {
4747
black_box(rng.gen::<$ty>());
@@ -96,9 +96,9 @@ macro_rules! init_gen {
9696
($fnn:ident, $gen:ident) => {
9797
#[bench]
9898
fn $fnn(b: &mut Bencher) {
99-
let mut rng: XorShiftRng = OsRng::new().unwrap().gen();
99+
let mut rng = XorShiftRng::new().unwrap();
100100
b.iter(|| {
101-
let r2: $gen = rng.gen();
101+
let r2 = $gen::from_rng(&mut rng).unwrap();
102102
black_box(r2);
103103
});
104104
}

src/distributions/exponential.rs

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
//! The exponential distribution.
1212
13-
use {Rng, Rand};
14-
use distributions::{ziggurat, ziggurat_tables, Sample, IndependentSample};
13+
use {Rng};
14+
use distributions::{ziggurat, ziggurat_tables, Distribution};
1515

16-
/// A wrapper around an `f64` to generate Exp(1) random numbers.
16+
/// Samples floating-point numbers according to the exponential distribution,
17+
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
18+
/// sampling with `-rng.gen::<f64>().ln()`, but faster.
1719
///
1820
/// See `Exp` for the general exponential distribution.
1921
///
@@ -27,33 +29,33 @@ use distributions::{ziggurat, ziggurat_tables, Sample, IndependentSample};
2729
/// College, Oxford
2830
///
2931
/// # Example
30-
///
3132
/// ```rust
32-
/// use rand::distributions::exponential::Exp1;
33+
/// use rand::{weak_rng, Rng};
34+
/// use rand::distributions::Exp1;
3335
///
34-
/// let Exp1(x) = rand::random();
35-
/// println!("{}", x);
36+
/// let val: f64 = weak_rng().sample(Exp1);
37+
/// println!("{}", val);
3638
/// ```
3739
#[derive(Clone, Copy, Debug)]
38-
pub struct Exp1(pub f64);
40+
pub struct Exp1;
3941

4042
// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
41-
impl Rand for Exp1 {
43+
impl Distribution<f64> for Exp1 {
4244
#[inline]
43-
fn rand<R:Rng>(rng: &mut R) -> Exp1 {
45+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
4446
#[inline]
4547
fn pdf(x: f64) -> f64 {
4648
(-x).exp()
4749
}
4850
#[inline]
49-
fn zero_case<R:Rng>(rng: &mut R, _u: f64) -> f64 {
51+
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
5052
ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln()
5153
}
5254

53-
Exp1(ziggurat(rng, false,
54-
&ziggurat_tables::ZIG_EXP_X,
55-
&ziggurat_tables::ZIG_EXP_F,
56-
pdf, zero_case))
55+
ziggurat(rng, false,
56+
&ziggurat_tables::ZIG_EXP_X,
57+
&ziggurat_tables::ZIG_EXP_F,
58+
pdf, zero_case)
5759
}
5860
}
5961

@@ -65,10 +67,10 @@ impl Rand for Exp1 {
6567
/// # Example
6668
///
6769
/// ```rust
68-
/// use rand::distributions::{Exp, IndependentSample};
70+
/// use rand::distributions::{Exp, Distribution};
6971
///
7072
/// let exp = Exp::new(2.0);
71-
/// let v = exp.ind_sample(&mut rand::thread_rng());
73+
/// let v = exp.sample(&mut rand::thread_rng());
7274
/// println!("{} is from a Exp(2) distribution", v);
7375
/// ```
7476
#[derive(Clone, Copy, Debug)]
@@ -87,28 +89,25 @@ impl Exp {
8789
}
8890
}
8991

90-
impl Sample<f64> for Exp {
91-
fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 { self.ind_sample(rng) }
92-
}
93-
impl IndependentSample<f64> for Exp {
94-
fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
95-
let Exp1(n) = rng.gen::<Exp1>();
92+
impl Distribution<f64> for Exp {
93+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
94+
let n: f64 = rng.sample(Exp1);
9695
n * self.lambda_inverse
9796
}
9897
}
9998

10099
#[cfg(test)]
101100
mod test {
102-
use distributions::{Sample, IndependentSample};
101+
use distributions::Distribution;
103102
use super::Exp;
104103

105104
#[test]
106105
fn test_exp() {
107-
let mut exp = Exp::new(10.0);
106+
let exp = Exp::new(10.0);
108107
let mut rng = ::test::rng(221);
109108
for _ in 0..1000 {
110109
assert!(exp.sample(&mut rng) >= 0.0);
111-
assert!(exp.ind_sample(&mut rng) >= 0.0);
110+
assert!(exp.sample(&mut rng) >= 0.0);
112111
}
113112
}
114113
#[test]

src/distributions/float.rs

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! Basic floating-point number distributions
12+
13+
14+
/// A distribution to sample floating point numbers uniformly in the open
15+
/// interval `(0, 1)` (not including either endpoint).
16+
///
17+
/// See also: [`Closed01`] for the closed `[0, 1]`; [`Uniform`] for the
18+
/// half-open `[0, 1)`.
19+
///
20+
/// # Example
21+
/// ```rust
22+
/// use rand::{weak_rng, Rng};
23+
/// use rand::distributions::Open01;
24+
///
25+
/// let val: f32 = weak_rng().sample(Open01);
26+
/// println!("f32 from (0,1): {}", val);
27+
/// ```
28+
///
29+
/// [`Uniform`]: struct.Uniform.html
30+
/// [`Closed01`]: struct.Closed01.html
31+
#[derive(Clone, Copy, Debug)]
32+
pub struct Open01;
33+
34+
/// A distribution to sample floating point numbers uniformly in the closed
35+
/// interval `[0, 1]` (including both endpoints).
36+
///
37+
/// See also: [`Open01`] for the open `(0, 1)`; [`Uniform`] for the half-open
38+
/// `[0, 1)`.
39+
///
40+
/// # Example
41+
/// ```rust
42+
/// use rand::{weak_rng, Rng};
43+
/// use rand::distributions::Closed01;
44+
///
45+
/// let val: f32 = weak_rng().sample(Closed01);
46+
/// println!("f32 from [0,1]: {}", val);
47+
/// ```
48+
///
49+
/// [`Uniform`]: struct.Uniform.html
50+
/// [`Open01`]: struct.Open01.html
51+
#[derive(Clone, Copy, Debug)]
52+
pub struct Closed01;
53+
54+
55+
macro_rules! float_impls {
56+
($mod_name:ident, $ty:ty, $mantissa_bits:expr, $method_name:ident) => {
57+
mod $mod_name {
58+
use Rng;
59+
use distributions::{Distribution, Uniform};
60+
use super::{Open01, Closed01};
61+
62+
const SCALE: $ty = (1u64 << $mantissa_bits) as $ty;
63+
64+
impl Distribution<$ty> for Uniform {
65+
/// Generate a floating point number in the half-open
66+
/// interval `[0,1)`.
67+
///
68+
/// See `Closed01` for the closed interval `[0,1]`,
69+
/// and `Open01` for the open interval `(0,1)`.
70+
#[inline]
71+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
72+
rng.$method_name()
73+
}
74+
}
75+
impl Distribution<$ty> for Open01 {
76+
#[inline]
77+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
78+
// add 0.5 * epsilon, so that smallest number is
79+
// greater than 0, and largest number is still
80+
// less than 1, specifically 1 - 0.5 * epsilon.
81+
rng.$method_name() + 0.5 / SCALE
82+
}
83+
}
84+
impl Distribution<$ty> for Closed01 {
85+
#[inline]
86+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
87+
// rescale so that 1.0 - epsilon becomes 1.0
88+
// precisely.
89+
rng.$method_name() * SCALE / (SCALE - 1.0)
90+
}
91+
}
92+
}
93+
}
94+
}
95+
float_impls! { f64_rand_impls, f64, 52, next_f64 }
96+
float_impls! { f32_rand_impls, f32, 23, next_f32 }
97+
98+
99+
#[cfg(test)]
100+
mod tests {
101+
use {Rng, RngCore, impls};
102+
use distributions::{Open01, Closed01};
103+
104+
const EPSILON32: f32 = ::core::f32::EPSILON;
105+
const EPSILON64: f64 = ::core::f64::EPSILON;
106+
107+
struct ConstantRng(u64);
108+
impl RngCore for ConstantRng {
109+
fn next_u32(&mut self) -> u32 {
110+
let ConstantRng(v) = *self;
111+
v as u32
112+
}
113+
fn next_u64(&mut self) -> u64 {
114+
let ConstantRng(v) = *self;
115+
v
116+
}
117+
118+
fn fill_bytes(&mut self, dest: &mut [u8]) {
119+
impls::fill_bytes_via_u64(self, dest)
120+
}
121+
}
122+
123+
#[test]
124+
fn floating_point_edge_cases() {
125+
let mut zeros = ConstantRng(0);
126+
assert_eq!(zeros.gen::<f32>(), 0.0);
127+
assert_eq!(zeros.gen::<f64>(), 0.0);
128+
129+
let mut one = ConstantRng(1);
130+
assert_eq!(one.gen::<f32>(), EPSILON32);
131+
assert_eq!(one.gen::<f64>(), EPSILON64);
132+
133+
let mut max = ConstantRng(!0);
134+
assert_eq!(max.gen::<f32>(), 1.0 - EPSILON32);
135+
assert_eq!(max.gen::<f64>(), 1.0 - EPSILON64);
136+
}
137+
138+
#[test]
139+
fn fp_closed_edge_cases() {
140+
let mut zeros = ConstantRng(0);
141+
assert_eq!(zeros.sample::<f32, _>(Closed01), 0.0);
142+
assert_eq!(zeros.sample::<f64, _>(Closed01), 0.0);
143+
144+
let mut one = ConstantRng(1);
145+
let one32 = one.sample::<f32, _>(Closed01);
146+
let one64 = one.sample::<f64, _>(Closed01);
147+
assert!(EPSILON32 < one32 && one32 < EPSILON32 * 1.01);
148+
assert!(EPSILON64 < one64 && one64 < EPSILON64 * 1.01);
149+
150+
let mut max = ConstantRng(!0);
151+
assert_eq!(max.sample::<f32, _>(Closed01), 1.0);
152+
assert_eq!(max.sample::<f64, _>(Closed01), 1.0);
153+
}
154+
155+
#[test]
156+
fn fp_open_edge_cases() {
157+
let mut zeros = ConstantRng(0);
158+
assert_eq!(zeros.sample::<f32, _>(Open01), 0.0 + EPSILON32 / 2.0);
159+
assert_eq!(zeros.sample::<f64, _>(Open01), 0.0 + EPSILON64 / 2.0);
160+
161+
let mut one = ConstantRng(1);
162+
let one32 = one.sample::<f32, _>(Open01);
163+
let one64 = one.sample::<f64, _>(Open01);
164+
assert!(EPSILON32 < one32 && one32 < EPSILON32 * 2.0);
165+
assert!(EPSILON64 < one64 && one64 < EPSILON64 * 2.0);
166+
167+
let mut max = ConstantRng(!0);
168+
assert_eq!(max.sample::<f32, _>(Open01), 1.0 - EPSILON32 / 2.0);
169+
assert_eq!(max.sample::<f64, _>(Open01), 1.0 - EPSILON64 / 2.0);
170+
}
171+
172+
#[test]
173+
fn rand_open() {
174+
// this is unlikely to catch an incorrect implementation that
175+
// generates exactly 0 or 1, but it keeps it sane.
176+
let mut rng = ::test::rng(510);
177+
for _ in 0..1_000 {
178+
// strict inequalities
179+
let f: f64 = rng.sample(Open01);
180+
assert!(0.0 < f && f < 1.0);
181+
182+
let f: f32 = rng.sample(Open01);
183+
assert!(0.0 < f && f < 1.0);
184+
}
185+
}
186+
187+
#[test]
188+
fn rand_closed() {
189+
let mut rng = ::test::rng(511);
190+
for _ in 0..1_000 {
191+
// strict inequalities
192+
let f: f64 = rng.sample(Closed01);
193+
assert!(0.0 <= f && f <= 1.0);
194+
195+
let f: f32 = rng.sample(Closed01);
196+
assert!(0.0 <= f && f <= 1.0);
197+
}
198+
}
199+
}

0 commit comments

Comments
 (0)