Skip to content

Commit dba18b8

Browse files
authored
Merge pull request #361 from pitdicker/sample_iter
Add an iterator to `Distribution`
2 parents fbf9572 + 4eb0831 commit dba18b8

File tree

4 files changed

+167
-32
lines changed

4 files changed

+167
-32
lines changed

benches/distributions.rs

+16
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,19 @@ gen_range_int!(gen_range_i32, i32, -200_000_000i32, 800_000_000);
134134
gen_range_int!(gen_range_i64, i64, 3i64, 123_456_789_123);
135135
#[cfg(feature = "i128_support")]
136136
gen_range_int!(gen_range_i128, i128, -12345678901234i128, 123_456_789_123_456_789);
137+
138+
#[bench]
139+
fn dist_iter(b: &mut Bencher) {
140+
let mut rng = XorShiftRng::new();
141+
let distr = Normal::new(-2.71828, 3.14159);
142+
let mut iter = distr.sample_iter(&mut rng);
143+
144+
b.iter(|| {
145+
let mut accum = 0.0;
146+
for _ in 0..::RAND_BENCH_N {
147+
accum += iter.next().unwrap();
148+
}
149+
black_box(accum);
150+
});
151+
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
152+
}

benches/misc.rs

+44
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,47 @@ macro_rules! sample_indices {
8888
sample_indices!(misc_sample_indices_10_of_1k, 10, 1000);
8989
sample_indices!(misc_sample_indices_50_of_1k, 50, 1000);
9090
sample_indices!(misc_sample_indices_100_of_1k, 100, 1000);
91+
92+
#[bench]
93+
fn gen_1k_iter_repeat(b: &mut Bencher) {
94+
use std::iter;
95+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
96+
b.iter(|| {
97+
let v: Vec<u64> = iter::repeat(()).map(|()| rng.gen()).take(128).collect();
98+
black_box(v);
99+
});
100+
b.bytes = 1024;
101+
}
102+
103+
#[bench]
104+
#[allow(deprecated)]
105+
fn gen_1k_gen_iter(b: &mut Bencher) {
106+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
107+
b.iter(|| {
108+
let v: Vec<u64> = rng.gen_iter().take(128).collect();
109+
black_box(v);
110+
});
111+
b.bytes = 1024;
112+
}
113+
114+
#[bench]
115+
fn gen_1k_sample_iter(b: &mut Bencher) {
116+
use rand::distributions::{Distribution, Uniform};
117+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
118+
b.iter(|| {
119+
let v: Vec<u64> = Uniform.sample_iter(&mut rng).take(128).collect();
120+
black_box(v);
121+
});
122+
b.bytes = 1024;
123+
}
124+
125+
#[bench]
126+
fn gen_1k_fill(b: &mut Bencher) {
127+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
128+
let mut buf = [0u64; 128];
129+
b.iter(|| {
130+
rng.fill(&mut buf[..]);
131+
black_box(buf);
132+
});
133+
b.bytes = 1024;
134+
}

src/distributions/mod.rs

+71-2
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,68 @@ mod impls {
130130

131131
/// Types (distributions) that can be used to create a random instance of `T`.
132132
pub trait Distribution<T> {
133-
/// Generate a random value of `T`, using `rng` as the
134-
/// source of randomness.
133+
/// Generate a random value of `T`, using `rng` as the source of randomness.
135134
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T;
135+
136+
/// Create an iterator that generates random values of `T`, using `rng` as
137+
/// the source of randomness.
138+
///
139+
/// # Example
140+
///
141+
/// ```rust
142+
/// use rand::thread_rng;
143+
/// use rand::distributions::{Distribution, Alphanumeric, Range, Uniform};
144+
///
145+
/// let mut rng = thread_rng();
146+
///
147+
/// // Vec of 16 x f32:
148+
/// let v: Vec<f32> = Uniform.sample_iter(&mut rng).take(16).collect();
149+
///
150+
/// // String:
151+
/// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect();
152+
///
153+
/// // Dice-rolling:
154+
/// let die_range = Range::new_inclusive(1, 6);
155+
/// let mut roll_die = die_range.sample_iter(&mut rng);
156+
/// while roll_die.next().unwrap() != 6 {
157+
/// println!("Not a 6; rolling again!");
158+
/// }
159+
/// ```
160+
fn sample_iter<'a, R: Rng>(&'a self, rng: &'a mut R)
161+
-> DistIter<'a, Self, R, T> where Self: Sized
162+
{
163+
DistIter {
164+
distr: self,
165+
rng: rng,
166+
phantom: ::core::marker::PhantomData,
167+
}
168+
}
169+
}
170+
171+
/// An iterator that generates random values of `T` with distribution `D`,
172+
/// using `R` as the source of randomness.
173+
///
174+
/// This `struct` is created by the [`sample_iter`] method on [`Distribution`].
175+
/// See its documentation for more.
176+
///
177+
/// [`Distribution`]: trait.Distribution.html
178+
/// [`sample_iter`]: trait.Distribution.html#method.sample_iter
179+
#[derive(Debug)]
180+
pub struct DistIter<'a, D, R, T> where D: Distribution<T> + 'a, R: Rng + 'a {
181+
distr: &'a D,
182+
rng: &'a mut R,
183+
phantom: ::core::marker::PhantomData<T>,
184+
}
185+
186+
impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
187+
where D: Distribution<T>, R: Rng + 'a
188+
{
189+
type Item = T;
190+
191+
#[inline(always)]
192+
fn next(&mut self) -> Option<T> {
193+
Some(self.distr.sample(self.rng))
194+
}
136195
}
137196

138197
impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
@@ -519,4 +578,14 @@ mod tests {
519578
let sampler = Exp::new(1.0);
520579
sampler.ind_sample(&mut ::test::rng(235));
521580
}
581+
582+
#[cfg(feature="std")]
583+
#[test]
584+
fn test_distributions_iter() {
585+
use distributions::Normal;
586+
let mut rng = ::test::rng(210);
587+
let distr = Normal::new(10.0, 10.0);
588+
let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect();
589+
println!("{:?}", results);
590+
}
522591
}

src/lib.rs

+36-30
Original file line numberDiff line numberDiff line change
@@ -300,33 +300,6 @@ pub trait Rand : Sized {
300300
/// }
301301
/// ```
302302
///
303-
/// # Iteration
304-
///
305-
/// Iteration over an `Rng` can be achieved using `iter::repeat` as follows:
306-
///
307-
/// ```rust
308-
/// use std::iter;
309-
/// use rand::{Rng, thread_rng};
310-
/// use rand::distributions::{Alphanumeric, Range};
311-
///
312-
/// let mut rng = thread_rng();
313-
///
314-
/// // Vec of 16 x f32:
315-
/// let v: Vec<f32> = iter::repeat(()).map(|()| rng.gen()).take(16).collect();
316-
///
317-
/// // String:
318-
/// let s: String = iter::repeat(())
319-
/// .map(|()| rng.sample(Alphanumeric))
320-
/// .take(7).collect();
321-
///
322-
/// // Dice-rolling:
323-
/// let die_range = Range::new_inclusive(1, 6);
324-
/// let mut roll_die = iter::repeat(()).map(|()| rng.sample(die_range));
325-
/// while roll_die.next().unwrap() != 6 {
326-
/// println!("Not a 6; rolling again!");
327-
/// }
328-
/// ```
329-
///
330303
/// [`RngCore`]: https://docs.rs/rand_core/0.1/rand_core/trait.RngCore.html
331304
pub trait Rng: RngCore {
332305
/// Fill `dest` entirely with random bytes (uniform value distribution),
@@ -408,6 +381,39 @@ pub trait Rng: RngCore {
408381
fn sample<T, D: Distribution<T>>(&mut self, distr: D) -> T {
409382
distr.sample(self)
410383
}
384+
385+
/// Create an iterator that generates values using the given distribution.
386+
///
387+
/// # Example
388+
///
389+
/// ```rust
390+
/// use rand::{thread_rng, Rng};
391+
/// use rand::distributions::{Alphanumeric, Range, Uniform};
392+
///
393+
/// let mut rng = thread_rng();
394+
///
395+
/// // Vec of 16 x f32:
396+
/// let v: Vec<f32> = thread_rng().sample_iter(&Uniform).take(16).collect();
397+
///
398+
/// // String:
399+
/// let s: String = rng.sample_iter(&Alphanumeric).take(7).collect();
400+
///
401+
/// // Combined values
402+
/// println!("{:?}", thread_rng().sample_iter(&Uniform).take(5)
403+
/// .collect::<Vec<(f64, bool)>>());
404+
///
405+
/// // Dice-rolling:
406+
/// let die_range = Range::new_inclusive(1, 6);
407+
/// let mut roll_die = rng.sample_iter(&die_range);
408+
/// while roll_die.next().unwrap() != 6 {
409+
/// println!("Not a 6; rolling again!");
410+
/// }
411+
/// ```
412+
fn sample_iter<'a, T, D: Distribution<T>>(&'a mut self, distr: &'a D)
413+
-> distributions::DistIter<'a, D, Self, T> where Self: Sized
414+
{
415+
distr.sample_iter(self)
416+
}
411417

412418
/// Return a random value supporting the [`Uniform`] distribution.
413419
///
@@ -443,7 +449,7 @@ pub trait Rng: RngCore {
443449
/// .collect::<Vec<(f64, bool)>>());
444450
/// ```
445451
#[allow(deprecated)]
446-
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
452+
#[deprecated(since="0.5.0", note="use Rng::sample_iter(&Uniform) instead")]
447453
fn gen_iter<T>(&mut self) -> Generator<T, &mut Self> where Uniform: Distribution<T> {
448454
Generator { rng: self, _marker: marker::PhantomData }
449455
}
@@ -528,7 +534,7 @@ pub trait Rng: RngCore {
528534
/// println!("{}", s);
529535
/// ```
530536
#[allow(deprecated)]
531-
#[deprecated(since="0.5.0", note="use distributions::Alphanumeric instead")]
537+
#[deprecated(since="0.5.0", note="use sample_iter(&Alphanumeric) instead")]
532538
fn gen_ascii_chars(&mut self) -> AsciiGenerator<&mut Self> {
533539
AsciiGenerator { rng: self }
534540
}
@@ -694,7 +700,7 @@ impl_as_byte_slice_arrays!(!div 4096, N,N,N,N,N,N,N,);
694700
/// [`Rng`]: trait.Rng.html
695701
#[derive(Debug)]
696702
#[allow(deprecated)]
697-
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
703+
#[deprecated(since="0.5.0", note="use Rng::sample_iter instead")]
698704
pub struct Generator<T, R: RngCore> {
699705
rng: R,
700706
_marker: marker::PhantomData<fn() -> T>,

0 commit comments

Comments
 (0)