Skip to content

Commit 942a3d5

Browse files
committed
Implement Rng.gen_ratio() and Bernoulli::new_ratio()
1 parent 3b0a884 commit 942a3d5

File tree

3 files changed

+106
-13
lines changed

3 files changed

+106
-13
lines changed

benches/misc.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ fn misc_gen_bool_var(b: &mut Bencher) {
3636
})
3737
}
3838

39+
#[bench]
40+
fn misc_gen_ratio_const(b: &mut Bencher) {
41+
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
42+
b.iter(|| {
43+
let mut accum = true;
44+
for _ in 0..::RAND_BENCH_N {
45+
accum ^= rng.gen_ratio(2, 3);
46+
}
47+
accum
48+
})
49+
}
50+
51+
#[bench]
52+
fn misc_gen_ratio_var(b: &mut Bencher) {
53+
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
54+
b.iter(|| {
55+
let mut accum = true;
56+
for i in 2..(::RAND_BENCH_N as u32 + 2) {
57+
accum ^= rng.gen_ratio(i, i + 1);
58+
}
59+
accum
60+
})
61+
}
62+
3963
#[bench]
4064
fn misc_bernoulli_const(b: &mut Bencher) {
4165
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();

src/distributions/bernoulli.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,28 @@ impl Bernoulli {
6767
};
6868
Bernoulli { p_int }
6969
}
70+
71+
/// Construct a new `Bernoulli` with the probability of success of
72+
/// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return
73+
/// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`.
74+
///
75+
/// If `numerator == denominator` then the returned `Bernoulli` will always
76+
/// return `true`. If `numerator == 0` it will always return `false`.
77+
///
78+
/// # Panics
79+
///
80+
/// If `denominator == 0` or `numerator > denominator`.
81+
///
82+
#[inline]
83+
pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
84+
assert!(numerator <= denominator);
85+
if numerator == denominator {
86+
return Bernoulli { p_int: ::core::u64::MAX }
87+
}
88+
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
89+
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
90+
Bernoulli { p_int }
91+
}
7092
}
7193

7294
impl Distribution<bool> for Bernoulli {
@@ -103,18 +125,27 @@ mod test {
103125
#[test]
104126
fn test_average() {
105127
const P: f64 = 0.3;
106-
let d = Bernoulli::new(P);
107-
const N: u32 = 10_000_000;
128+
const NUM: u32 = 3;
129+
const DENOM: u32 = 10;
130+
let d1 = Bernoulli::new(P);
131+
let d2 = Bernoulli::from_ratio(NUM, DENOM);
132+
const N: u32 = 100_000;
108133

109-
let mut sum: u32 = 0;
134+
let mut sum1: u32 = 0;
135+
let mut sum2: u32 = 0;
110136
let mut rng = ::test::rng(2);
111137
for _ in 0..N {
112-
if d.sample(&mut rng) {
113-
sum += 1;
138+
if d1.sample(&mut rng) {
139+
sum1 += 1;
140+
}
141+
if d2.sample(&mut rng) {
142+
sum2 += 1;
114143
}
115144
}
116-
let avg = (sum as f64) / (N as f64);
145+
let avg1 = (sum1 as f64) / (N as f64);
146+
assert!((avg1 - P).abs() < 5e-3);
117147

118-
assert!((avg - P).abs() < 1e-3);
148+
let avg2 = (sum2 as f64) / (N as f64);
149+
assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3);
119150
}
120151
}

src/lib.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ pub trait Rng: RngCore {
401401
/// ```
402402
///
403403
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
404-
fn gen_range<T: PartialOrd + SampleUniform>(&mut self, low: T, high: T) -> T {
404+
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
405405
T::Sampler::sample_single(low, high, self)
406406
}
407407

@@ -523,8 +523,6 @@ pub trait Rng: RngCore {
523523

524524
/// Return a bool with a probability `p` of being true.
525525
///
526-
/// This is a wrapper around [`distributions::Bernoulli`].
527-
///
528526
/// # Example
529527
///
530528
/// ```
@@ -536,15 +534,38 @@ pub trait Rng: RngCore {
536534
///
537535
/// # Panics
538536
///
539-
/// If `p` < 0 or `p` > 1.
540-
///
541-
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
537+
/// If `p < 0` or `p > 1`.
542538
#[inline]
543539
fn gen_bool(&mut self, p: f64) -> bool {
544540
let d = distributions::Bernoulli::new(p);
545541
self.sample(d)
546542
}
547543

544+
/// Return a bool with a probability of `numerator/denominator` of being
545+
/// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of
546+
/// returning true. If `numerator == denominator`, then the returned value
547+
/// is guaranteed to be `true`. If `numerator == 0`, then the returned
548+
/// value is guaranteed to be `false`.
549+
///
550+
/// # Panics
551+
///
552+
/// If `denominator == 0` or `numerator > denominator`.
553+
///
554+
/// # Example
555+
///
556+
/// ```
557+
/// use rand::{thread_rng, Rng};
558+
///
559+
/// let mut rng = thread_rng();
560+
/// println!("{}", rng.gen_ratio(2, 3));
561+
/// ```
562+
///
563+
#[inline]
564+
fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool {
565+
let d = distributions::Bernoulli::from_ratio(numerator, denominator);
566+
self.sample(d)
567+
}
568+
548569
/// Return a random element from `values`.
549570
///
550571
/// Return `None` if `values` is empty.
@@ -1196,4 +1217,21 @@ mod test {
11961217
(u8, i8, u16, i16, u32, i32, u64, i64),
11971218
(f32, (f64, (f64,)))) = random();
11981219
}
1220+
1221+
#[test]
1222+
fn test_gen_ratio_average() {
1223+
const NUM: u32 = 3;
1224+
const DENOM: u32 = 10;
1225+
const N: u32 = 100_000;
1226+
1227+
let mut sum: u32 = 0;
1228+
let mut rng = rng(111);
1229+
for _ in 0..N {
1230+
if rng.gen_ratio(NUM, DENOM) {
1231+
sum += 1;
1232+
}
1233+
}
1234+
let avg = (sum as f64) / (N as f64);
1235+
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
1236+
}
11991237
}

0 commit comments

Comments
 (0)