Skip to content

Commit a3a9fc3

Browse files
committed
Implement Rng.gen_ratio() and Bernoulli::new_ratio()
1 parent 276c8be commit a3a9fc3

File tree

3 files changed

+113
-11
lines changed

3 files changed

+113
-11
lines changed

benches/misc.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ fn misc_gen_bool_var(b: &mut Bencher) {
3636
})
3737
}
3838

39+
#[bench]
40+
fn misc_gen_ratio_const(b: &mut Bencher) {
41+
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
42+
b.iter(|| {
43+
let mut accum = true;
44+
for _ in 0..::RAND_BENCH_N {
45+
accum ^= rng.gen_ratio(2, 3);
46+
}
47+
accum
48+
})
49+
}
50+
51+
#[bench]
52+
fn misc_gen_ratio_var(b: &mut Bencher) {
53+
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
54+
b.iter(|| {
55+
let mut accum = true;
56+
for i in 2..(::RAND_BENCH_N as u32 + 2) {
57+
accum ^= rng.gen_ratio(i, i + 1);
58+
}
59+
accum
60+
})
61+
}
62+
3963
#[bench]
4064
fn misc_bernoulli_const(b: &mut Bencher) {
4165
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();

src/distributions/bernoulli.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,28 @@ impl Bernoulli {
6767
};
6868
Bernoulli { p_int }
6969
}
70+
71+
/// Construct a new `Bernoulli` with the probability of success of
72+
/// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return
73+
/// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`.
74+
///
75+
/// If `numerator == denominator` then the returned `Bernoulli` will always
76+
/// return `true`. If `numerator == 0` it will always return `false`.
77+
///
78+
/// # Panics
79+
///
80+
/// If `denominator == 0` or `numerator > denominator`.
81+
///
82+
#[inline]
83+
pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
84+
assert!(numerator <= denominator);
85+
if numerator == denominator {
86+
return Bernoulli { p_int: ::core::u64::MAX }
87+
}
88+
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
89+
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
90+
Bernoulli { p_int }
91+
}
7092
}
7193

7294
impl Distribution<bool> for Bernoulli {
@@ -103,18 +125,27 @@ mod test {
103125
#[test]
104126
fn test_average() {
105127
const P: f64 = 0.3;
106-
let d = Bernoulli::new(P);
107-
const N: u32 = 10_000_000;
128+
const NUM: u32 = 3;
129+
const DENOM: u32 = 10;
130+
let d1 = Bernoulli::new(P);
131+
let d2 = Bernoulli::from_ratio(NUM, DENOM);
132+
const N: u32 = 100_000;
108133

109-
let mut sum: u32 = 0;
134+
let mut sum1: u32 = 0;
135+
let mut sum2: u32 = 0;
110136
let mut rng = ::test::rng(2);
111137
for _ in 0..N {
112-
if d.sample(&mut rng) {
113-
sum += 1;
138+
if d1.sample(&mut rng) {
139+
sum1 += 1;
140+
}
141+
if d2.sample(&mut rng) {
142+
sum2 += 1;
114143
}
115144
}
116-
let avg = (sum as f64) / (N as f64);
145+
let avg1 = (sum1 as f64) / (N as f64);
146+
assert!((avg1 - P).abs() < 5e-3);
117147

118-
assert!((avg - P).abs() < 1e-3);
148+
let avg2 = (sum2 as f64) / (N as f64);
149+
assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3);
119150
}
120151
}

src/lib.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ pub trait Rng: RngCore {
387387
/// ```
388388
///
389389
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
390-
fn gen_range<T: PartialOrd + SampleUniform>(&mut self, low: T, high: T) -> T {
390+
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
391391
T::Sampler::sample_single(low, high, self)
392392
}
393393

@@ -509,7 +509,8 @@ pub trait Rng: RngCore {
509509

510510
/// Return a bool with a probability `p` of being true.
511511
///
512-
/// This is a wrapper around [`distributions::Bernoulli`].
512+
/// See also the [`Bernoulli`] distribution, which may be faster if
513+
/// sampling from the same probability repeatedly.
513514
///
514515
/// # Example
515516
///
@@ -522,15 +523,44 @@ pub trait Rng: RngCore {
522523
///
523524
/// # Panics
524525
///
525-
/// If `p` < 0 or `p` > 1.
526+
/// If `p < 0` or `p > 1`.
526527
///
527-
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
528+
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
528529
#[inline]
529530
fn gen_bool(&mut self, p: f64) -> bool {
530531
let d = distributions::Bernoulli::new(p);
531532
self.sample(d)
532533
}
533534

535+
/// Return a bool with a probability of `numerator/denominator` of being
536+
/// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of
537+
/// returning true. If `numerator == denominator`, then the returned value
538+
/// is guaranteed to be `true`. If `numerator == 0`, then the returned
539+
/// value is guaranteed to be `false`.
540+
///
541+
/// See also the [`Bernoulli`] distribution, which may be faster if
542+
/// sampling from the same `numerator` and `denominator` repeatedly.
543+
///
544+
/// # Panics
545+
///
546+
/// If `denominator == 0` or `numerator > denominator`.
547+
///
548+
/// # Example
549+
///
550+
/// ```
551+
/// use rand::{thread_rng, Rng};
552+
///
553+
/// let mut rng = thread_rng();
554+
/// println!("{}", rng.gen_ratio(2, 3));
555+
/// ```
556+
///
557+
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
558+
#[inline]
559+
fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool {
560+
let d = distributions::Bernoulli::from_ratio(numerator, denominator);
561+
self.sample(d)
562+
}
563+
534564
/// Return a random element from `values`.
535565
///
536566
/// Return `None` if `values` is empty.
@@ -1017,4 +1047,21 @@ mod test {
10171047
(u8, i8, u16, i16, u32, i32, u64, i64),
10181048
(f32, (f64, (f64,)))) = random();
10191049
}
1050+
1051+
#[test]
1052+
fn test_gen_ratio_average() {
1053+
const NUM: u32 = 3;
1054+
const DENOM: u32 = 10;
1055+
const N: u32 = 100_000;
1056+
1057+
let mut sum: u32 = 0;
1058+
let mut rng = rng(111);
1059+
for _ in 0..N {
1060+
if rng.gen_ratio(NUM, DENOM) {
1061+
sum += 1;
1062+
}
1063+
}
1064+
let avg = (sum as f64) / (N as f64);
1065+
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
1066+
}
10201067
}

0 commit comments

Comments
 (0)