Skip to content

Commit e196960

Browse files
authored
Merge pull request rust-random#798 from vks/unit-ball
Implement unit ball sampling and rename UnitSphereSurface -> UnitSphere
2 parents bbd8ea4 + a5194aa commit e196960

File tree

6 files changed

+163
-19
lines changed

6 files changed

+163
-19
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ appveyor = { repository = "rust-random/rand" }
2020

2121
[dependencies]
2222
rand = { path = "..", version = ">=0.5, <=0.7" }
23+
24+
[dev-dependencies]
25+
# Histogram implementation for testing uniformity
26+
average = "0.9.2"

src/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,18 @@
5555
//! - [`Triangular`] distribution
5656
//! - Multivariate probability distributions
5757
//! - [`Dirichlet`] distribution
58-
//! - [`UnitSphereSurface`] distribution
58+
//! - [`UnitSphere`] distribution
59+
//! - [`UnitBall`] distribution
5960
//! - [`UnitCircle`] distribution
61+
//! - [`UnitDisc`] distribution
6062
6163
pub use rand::distributions::{Distribution, DistIter, Standard,
6264
Alphanumeric, Uniform, OpenClosed01, Open01, Bernoulli, uniform, weighted};
6365

64-
pub use self::unit_sphere::UnitSphereSurface;
66+
pub use self::unit_sphere::UnitSphere;
67+
pub use self::unit_ball::UnitBall;
6568
pub use self::unit_circle::UnitCircle;
69+
pub use self::unit_disc::UnitDisc;
6670
pub use self::gamma::{Gamma, Error as GammaError, ChiSquared, ChiSquaredError,
6771
FisherF, FisherFError, StudentT, Beta, BetaError};
6872
pub use self::normal::{Normal, Error as NormalError, LogNormal, StandardNormal};
@@ -78,7 +82,9 @@ pub use self::weibull::{Weibull, Error as WeibullError};
7882
pub use self::utils::Float;
7983

8084
mod unit_sphere;
85+
mod unit_ball;
8186
mod unit_circle;
87+
mod unit_disc;
8288
mod gamma;
8389
mod normal;
8490
mod exponential;

src/unit_ball.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright 2019 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use rand::Rng;
10+
use crate::{Distribution, Uniform, uniform::SampleUniform};
11+
use crate::utils::Float;
12+
13+
/// Samples uniformly from the unit ball (surface and interior) in three
14+
/// dimensions.
15+
///
16+
/// Implemented via rejection sampling.
17+
///
18+
///
19+
/// # Example
20+
///
21+
/// ```
22+
/// use rand_distr::{UnitBall, Distribution};
23+
///
24+
/// let v: [f64; 3] = UnitBall.sample(&mut rand::thread_rng());
25+
/// println!("{:?} is from the unit ball.", v)
26+
/// ```
27+
#[derive(Clone, Copy, Debug)]
28+
pub struct UnitBall;
29+
30+
impl<N: Float + SampleUniform> Distribution<[N; 3]> for UnitBall {
31+
#[inline]
32+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [N; 3] {
33+
let uniform = Uniform::new(N::from(-1.), N::from(1.));
34+
let mut x1;
35+
let mut x2;
36+
let mut x3;
37+
loop {
38+
x1 = uniform.sample(rng);
39+
x2 = uniform.sample(rng);
40+
x3 = uniform.sample(rng);
41+
if x1*x1 + x2*x2 + x3*x3 <= N::from(1.) {
42+
break;
43+
}
44+
}
45+
[x1, x2, x3]
46+
}
47+
}
48+
49+
#[cfg(test)]
50+
mod tests {
51+
use crate::Distribution;
52+
use super::UnitBall;
53+
54+
#[test]
55+
fn value_stability() {
56+
let mut rng = crate::test::rng(2);
57+
let expected = [
58+
[-0.42140140089381806, 0.4245276448803281, -0.7109276652167549],
59+
[0.6683277779168173, 0.12753134283863998, 0.6843687153674809],
60+
[-0.80397712218568, -0.0015797354643116712, 0.1588400395442835],
61+
];
62+
let samples: [[f64; 3]; 3] = [
63+
UnitBall.sample(&mut rng),
64+
UnitBall.sample(&mut rng),
65+
UnitBall.sample(&mut rng),
66+
];
67+
assert_eq!(samples, expected);
68+
}
69+
}

src/unit_disc.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2019 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use rand::Rng;
10+
use crate::{Distribution, Uniform, uniform::SampleUniform};
11+
use crate::utils::Float;
12+
13+
/// Samples uniformly from the unit disc in two dimensions.
14+
///
15+
/// Implemented via rejection sampling.
16+
///
17+
///
18+
/// # Example
19+
///
20+
/// ```
21+
/// use rand_distr::{UnitDisc, Distribution};
22+
///
23+
/// let v: [f64; 2] = UnitDisc.sample(&mut rand::thread_rng());
24+
/// println!("{:?} is from the unit Disc.", v)
25+
/// ```
26+
#[derive(Clone, Copy, Debug)]
27+
pub struct UnitDisc;
28+
29+
impl<N: Float + SampleUniform> Distribution<[N; 2]> for UnitDisc {
30+
#[inline]
31+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [N; 2] {
32+
let uniform = Uniform::new(N::from(-1.), N::from(1.));
33+
let mut x1;
34+
let mut x2;
35+
loop {
36+
x1 = uniform.sample(rng);
37+
x2 = uniform.sample(rng);
38+
if x1*x1 + x2*x2 <= N::from(1.) {
39+
break;
40+
}
41+
}
42+
[x1, x2]
43+
}
44+
}
45+
46+
#[cfg(test)]
47+
mod tests {
48+
use crate::Distribution;
49+
use super::UnitDisc;
50+
51+
#[test]
52+
fn value_stability() {
53+
let mut rng = crate::test::rng(2);
54+
let expected = [
55+
[-0.13921053103419823, -0.42140140089381806],
56+
[0.4245276448803281, -0.7109276652167549],
57+
[0.6683277779168173, 0.12753134283863998],
58+
];
59+
let samples: [[f64; 2]; 3] = [
60+
UnitDisc.sample(&mut rng),
61+
UnitDisc.sample(&mut rng),
62+
UnitDisc.sample(&mut rng),
63+
];
64+
assert_eq!(samples, expected);
65+
}
66+
}

src/unit_sphere.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018 Developers of the Rand project.
1+
// Copyright 2018-2019 Developers of the Rand project.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -18,19 +18,19 @@ use crate::utils::Float;
1818
/// # Example
1919
///
2020
/// ```
21-
/// use rand_distr::{UnitSphereSurface, Distribution};
21+
/// use rand_distr::{UnitSphere, Distribution};
2222
///
23-
/// let v: [f64; 3] = UnitSphereSurface.sample(&mut rand::thread_rng());
23+
/// let v: [f64; 3] = UnitSphere.sample(&mut rand::thread_rng());
2424
/// println!("{:?} is from the unit sphere surface.", v)
2525
/// ```
2626
///
2727
/// [^1]: Marsaglia, George (1972). [*Choosing a Point from the Surface of a
2828
/// Sphere.*](https://doi.org/10.1214/aoms/1177692644)
2929
/// Ann. Math. Statist. 43, no. 2, 645--646.
3030
#[derive(Clone, Copy, Debug)]
31-
pub struct UnitSphereSurface;
31+
pub struct UnitSphere;
3232

33-
impl<N: Float + SampleUniform> Distribution<[N; 3]> for UnitSphereSurface {
33+
impl<N: Float + SampleUniform> Distribution<[N; 3]> for UnitSphere {
3434
#[inline]
3535
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [N; 3] {
3636
let uniform = Uniform::new(N::from(-1.), N::from(1.));
@@ -49,7 +49,7 @@ impl<N: Float + SampleUniform> Distribution<[N; 3]> for UnitSphereSurface {
4949
#[cfg(test)]
5050
mod tests {
5151
use crate::Distribution;
52-
use super::UnitSphereSurface;
52+
use super::UnitSphere;
5353

5454
/// Assert that two numbers are almost equal to each other.
5555
///
@@ -71,7 +71,7 @@ mod tests {
7171
fn norm() {
7272
let mut rng = crate::test::rng(1);
7373
for _ in 0..1000 {
74-
let x: [f64; 3] = UnitSphereSurface.sample(&mut rng);
74+
let x: [f64; 3] = UnitSphere.sample(&mut rng);
7575
assert_almost_eq!(x[0]*x[0] + x[1]*x[1] + x[2]*x[2], 1., 1e-15);
7676
}
7777
}
@@ -85,9 +85,9 @@ mod tests {
8585
[0.9795722330927367, 0.18692349236651176, 0.07414747571708524],
8686
];
8787
let samples: [[f64; 3]; 3] = [
88-
UnitSphereSurface.sample(&mut rng),
89-
UnitSphereSurface.sample(&mut rng),
90-
UnitSphereSurface.sample(&mut rng),
88+
UnitSphere.sample(&mut rng),
89+
UnitSphere.sample(&mut rng),
90+
UnitSphere.sample(&mut rng),
9191
];
9292
assert_eq!(samples, expected);
9393
}

tests/uniformity.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
#![cfg(feature = "std")]
10-
119
#[macro_use]
1210
extern crate average;
1311
extern crate rand;
12+
extern crate rand_distr;
13+
extern crate core;
1414

1515
use average::Histogram;
1616
use rand::distributions::Distribution;
1717
use rand::FromEntropy;
18-
use std as core;
1918

2019
const N_BINS: usize = 100;
2120
const N_SAMPLES: u32 = 1_000_000;
@@ -28,10 +27,10 @@ fn unit_sphere() {
2827
const N_DIM: usize = 3;
2928
let h = Histogram100::with_const_width(-1., 1.);
3029
let mut histograms = [h.clone(), h.clone(), h];
31-
let dist = rand::distributions::UnitSphereSurface::new();
30+
let dist = rand_distr::UnitSphere;
3231
let mut rng = rand::rngs::SmallRng::from_entropy();
3332
for _ in 0..N_SAMPLES {
34-
let v = dist.sample(&mut rng);
33+
let v: [f64; 3] = dist.sample(&mut rng);
3534
for i in 0..N_DIM {
3635
histograms[i].add(v[i]).map_err(
3736
|e| { println!("v: {}", v[i]); e }
@@ -52,10 +51,10 @@ fn unit_sphere() {
5251
fn unit_circle() {
5352
use std::f64::consts::PI;
5453
let mut h = Histogram100::with_const_width(-PI, PI);
55-
let dist = rand::distributions::UnitCircle::new();
54+
let dist = rand_distr::UnitCircle;
5655
let mut rng = rand::rngs::SmallRng::from_entropy();
5756
for _ in 0..N_SAMPLES {
58-
let v = dist.sample(&mut rng);
57+
let v: [f64; 2] = dist.sample(&mut rng);
5958
h.add(v[0].atan2(v[1])).unwrap();
6059
}
6160
let sum: u64 = h.bins().iter().sum();

0 commit comments

Comments
 (0)