Skip to content

Commit 54f4580

Browse files
dhardy, sicking
authored and committed
Float sampling: improve high precision sampling; add mean test
(The mean test is totally inadequate for checking high precision.)
1 parent 7c8b1e5 commit 54f4580

File tree

1 file changed

+76
-49
lines changed

1 file changed

+76
-49
lines changed

src/distributions/float.rs

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
//! Basic floating-point number distributions
1212
13-
use core::mem;
13+
use core::{cmp, mem};
1414
use Rng;
1515
use distributions::{Distribution, Standard};
1616
use distributions::utils::CastFromInt;
@@ -98,22 +98,20 @@ impl<F: HPFloatHelper> HighPrecision<F> {
9898
}
9999

100100
/// Generate a floating point number in the half-open interval `[0, 1)` with a
101-
/// uniform distribution.
101+
/// uniform distribution, with as much precision as the floating-point type
102+
/// can represent, including sub-normals.
102103
///
103-
/// This is different from `Uniform` in that it uses all 32 bits of an RNG for a
104-
/// `f32`, instead of only 23, the number of bits that fit in a floats fraction
105-
/// (or 64 instead of 52 bits for a `f64`).
104+
/// Technically 0 is representable, but the probability of occurrence is
105+
/// remote (1 in 2^149 for `f32` or 1 in 2^1074 for `f64`).
106106
///
107-
/// The smallest interval between values that can be generated is 2^-32
108-
/// (2.3283064e-10) for `f32`, and 2^-64 (5.421010862427522e-20) for `f64`.
109-
/// But this interval increases further away from zero because of limitations of
110-
/// the floating point format. Close to 1.0 the interval is 2^-24 (5.9604645e-8)
111-
/// for `f32`, and 2^-53 (1.1102230246251565) for `f64`. Compare this with
112-
/// `Uniform`, which has a fixed interval of 2^23 and 2^-52 respectively.
113-
///
114-
/// Note: in the future this may change change to request even more bits from
115-
/// the RNG if the value gets very close to 0.0, so it always has as many digits
116-
/// of precision as the float can represent.
107+
/// This is different from `Uniform` in that it uses as many random bits as
108+
/// required to get high precision close to 0. Normally only a single call to
109+
/// the source RNG is required (32 bits for `f32` or 64 bits for `f64`); 1 in
110+
/// 2^9 (`f32`) or 2^12 (`f64`) samples need an extra call; of these 1 in 2^32
111+
/// or 1 in 2^64 require a third call, etc.; i.e. even for `f32` a third call is
112+
/// almost impossible to observe with an unbiased RNG. Due to the extra logic
113+
/// there is some performance overhead relative to `Uniform`; this is more
114+
/// significant for `f32` than for `f64`.
117115
///
118116
/// # Example
119117
/// ```rust
@@ -231,24 +229,10 @@ float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
231229

232230

233231
macro_rules! high_precision_float_impls {
234-
($ty:ty, $uty:ty, $ity:ty, $fraction_bits:expr, $exponent_bits:expr) => {
232+
($ty:ty, $uty:ty, $ity:ty, $fraction_bits:expr, $exponent_bits:expr, $exponent_bias:expr) => {
235233
impl Distribution<$ty> for HighPrecision01 {
236234
/// Generate a floating point number in the half-open interval
237-
/// `[0, 1)` with a uniform distribution.
238-
///
239-
/// This is different from `Uniform` in that it uses all 32 bits
240-
/// of an RNG for a `f32`, instead of only 23, the number of bits
241-
/// that fit in a floats fraction (or 64 instead of 52 bits for a
242-
/// `f64`).
243-
///
244-
/// # Example
245-
/// ```rust
246-
/// use rand::{NewRng, SmallRng, Rng};
247-
/// use rand::distributions::HighPrecision01;
248-
///
249-
/// let val: f32 = SmallRng::new().sample(HighPrecision01);
250-
/// println!("f32 from [0,1): {}", val);
251-
/// ```
235+
/// `[0, 1)` with a uniform distribution. See [`HighPrecision01`].
252236
///
253237
/// # Algorithm
254238
/// (Note: this description used values that apply to `f32` to
@@ -257,34 +241,50 @@ macro_rules! high_precision_float_impls {
257241
/// The trick to generate a uniform distribution over [0,1) is to
258242
/// set the exponent to the -log2 of the remaining random bits. A
259243
/// simpler alternative to -log2 is to count the number of trailing
260-
/// zero's of the random bits.
244+
/// zeros in the random bits. In the case where all bits are zero,
245+
/// we simply generate a new random number and add the number of
246+
/// trailing zeros to the previous count (up to maximum exponent).
261247
///
262248
/// Each exponent is responsible for a piece of the distribution
263-
/// between [0,1). The exponent -1 fills the part [0.5,1). -2 fills
264-
/// [0.25,0.5). The lowest exponent we can get is -10. So a problem
265-
/// with this method is that we can not fill the part between zero
266-
/// and the part from -10. The solution is to treat numbers with an
267-
/// exponent of -10 as if they have -9 as exponent, and substract
268-
/// 2^-9 (implemented in the `fallback` function).
249+
/// between [0,1). We take the above exponent, add 1 and negate;
250+
/// thus with probability 1/2 we have exponent -1 which fills the
251+
/// range [0.5,1); with probability 1/4 we have exponent -2 which
252+
/// fills the range [0.25,0.5), etc. If the exponent reaches the
253+
/// minimum allowed, the floating-point format drops the implied
254+
/// fraction bit, thus allowing numbers down to 0 to be sampled.
255+
///
256+
/// [`HighPrecision01`]: struct.HighPrecision01.html
269257
#[inline]
270258
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
259+
// Unusual case. Separate function to allow inlining of rest.
271260
#[inline(never)]
272-
fn fallback(fraction: $uty) -> $ty {
273-
let float_size = (mem::size_of::<$ty>() * 8) as i32;
274-
let min_exponent = $fraction_bits as i32 - float_size;
275-
let adjust = // 2^MIN_EXPONENT
276-
(0 as $uty).into_float_with_exponent(min_exponent);
277-
fraction.into_float_with_exponent(min_exponent) - adjust
261+
fn fallback<R: Rng + ?Sized>(mut exp: i32, fraction: $uty, rng: &mut R) -> $ty {
262+
// Performance impact of code here is negligible.
263+
let bits = rng.gen::<$uty>();
264+
exp += bits.trailing_zeros() as i32;
265+
// If RNG were guaranteed unbiased we could skip the
266+
// check against exp; unfortunately it may be biased.
267+
// Worst case ("zeros" RNG) has recursion depth 16.
268+
if bits == 0 && exp < $exponent_bias {
269+
return fallback(exp, fraction, rng);
270+
}
271+
exp = cmp::min(exp, $exponent_bias);
272+
fraction.into_float_with_exponent(-exp)
278273
}
279274

280275
let fraction_mask = (1 << $fraction_bits) - 1;
281276
let value: $uty = rng.gen();
282277

283278
let fraction = value & fraction_mask;
284279
let remaining = value >> $fraction_bits;
285-
// If `remaing ==0` we end up in the lowest exponent, which
286-
// needs special treatment.
287-
if remaining == 0 { return fallback(fraction) }
280+
if remaining == 0 {
281+
// exp is compile-time constant so this reduces to a function call:
282+
let size_bits = (mem::size_of::<$ty>() * 8) as i32;
283+
let exp = (size_bits - $fraction_bits as i32) + 1;
284+
return fallback(exp, fraction, rng);
285+
}
286+
287+
// Usual case: exponent from -1 to -9 (f32) or -12 (f64)
288288
let exp = remaining.trailing_zeros() as i32 + 1;
289289
fraction.into_float_with_exponent(-exp)
290290
}
@@ -446,8 +446,8 @@ macro_rules! high_precision_float_impls {
446446
}
447447
}
448448

449-
high_precision_float_impls! { f32, u32, i32, 23, 8 }
450-
high_precision_float_impls! { f64, u64, i64, 52, 11 }
449+
high_precision_float_impls! { f32, u32, i32, 23, 8, 127 }
450+
high_precision_float_impls! { f64, u64, i64, 52, 11, 1023 }
451451

452452

453453
#[cfg(test)]
@@ -731,4 +731,31 @@ mod tests {
731731
assert_eq!(ones.sample::<f32, _>(HighPrecision01), 0.99999994);
732732
assert_eq!(ones.sample::<f64, _>(HighPrecision01), 0.9999999999999999);
733733
}
734+
735+
#[cfg(feature="std")] mod mean {
736+
use Rng;
737+
use distributions::{Standard, HighPrecision01};
738+
739+
macro_rules! test_mean {
740+
($name:ident, $ty:ty, $distr:expr) => {
741+
#[test]
742+
fn $name() {
743+
// TODO: no need to &mut here:
744+
let mut r = ::test::rng(602);
745+
let mut total: $ty = 0.0;
746+
const N: u32 = 1_000_000;
747+
for _ in 0..N {
748+
total += r.sample::<$ty, _>($distr);
749+
}
750+
let avg = total / (N as $ty);
751+
//println!("average over {} samples: {}", N, avg);
752+
assert!(0.499 < avg && avg < 0.501);
753+
}
754+
} }
755+
756+
test_mean!(test_mean_f32, f32, Standard);
757+
test_mean!(test_mean_f64, f64, Standard);
758+
test_mean!(test_mean_high_f32, f32, HighPrecision01);
759+
test_mean!(test_mean_high_f64, f64, HighPrecision01);
760+
}
734761
}

0 commit comments

Comments
 (0)