Skip to content

Commit 922bb09

Browse files
committed
Impl new choose_multiple functions
1 parent 18e7167 commit 922bb09

File tree

1 file changed

+107
-56
lines changed

1 file changed

+107
-56
lines changed

src/seq.rs

+107-56
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
//!
1313
//! TODO: module doc
1414
15+
use core::ops::Index;
16+
1517
use super::Rng;
1618

1719
// BTreeMap is not as fast in tests, but better than nothing.
1820
#[cfg(feature="std")] use std::collections::HashMap;
1921
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::btree_map::BTreeMap;
2022

21-
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec;
22-
23+
#[cfg(feature="std")] use std::vec;
24+
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::{vec, Vec};
2325

2426
/// Extension trait on slices, providing random mutation and sampling methods.
2527
///
@@ -69,7 +71,7 @@ pub trait SliceExt {
6971
///
7072
/// Complexity is expected to be the same as `sample_indices`.
7173
#[cfg(feature = "alloc")]
72-
fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item>
74+
fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
7375
where R: Rng + ?Sized;
7476

7577
/// Shuffle a mutable slice in place.
@@ -129,33 +131,78 @@ pub trait IteratorExt: Iterator + Sized {
129131

130132
/// Collects values at random from the iterator into a supplied buffer
131133
/// until that buffer is filled.
132-
///
134+
///
135+
/// Although the elements are selected randomly, the order of elements in
136+
/// the buffer is neither stable nor fully random. If random ordering is
137+
/// desired, shuffle the result.
138+
///
133139
/// Returns the number of elements added to the buffer. This equals `buf.len()`
134140
/// unless the iterator contains insufficient elements, in which case this
135141
/// equals the number of elements available.
136142
///
137-
/// Complexity is TODO
138-
fn choose_multiple_fill<R>(self, rng: &mut R, amount: usize) -> usize
139-
where R: Rng + ?Sized
143+
/// Complexity is `O(n)` where `n` is the length of the iterator.
144+
fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item])
145+
-> usize where R: Rng + ?Sized
140146
{
141-
unimplemented!()
147+
let amount = buf.len();
148+
let mut len = 0;
149+
while len < amount {
150+
if let Some(elem) = self.next() {
151+
buf[len] = elem;
152+
len += 1;
153+
} else {
154+
// Iterator exhausted; stop early
155+
return len;
156+
}
157+
}
158+
159+
// Continue, since the iterator was not exhausted
160+
for (i, elem) in self.enumerate() {
161+
let k = rng.gen_range(0, i + 1 + amount);
162+
if k < amount {
163+
buf[k] = elem;
164+
}
165+
}
166+
len
142167
}
143168

144169
/// Collects `amount` values at random from the iterator into a vector.
145170
///
146-
/// This is a convenience wrapper around `choose_multiple_fill`.
171+
/// This is equivalent to `choose_multiple_fill` except for the result type.
147172
///
173+
/// Although the elements are selected randomly, the order of elements in
174+
/// the buffer is neither stable nor fully random. If random ordering is
175+
/// desired, shuffle the result.
176+
///
148177
/// The length of the returned vector equals `amount` unless the iterator
149178
/// contains insufficient elements, in which case it equals the number of
150179
/// elements available.
151180
///
152-
/// Complexity is TODO
181+
/// Complexity is `O(n)` where `n` is the length of the iterator.
153182
#[cfg(feature = "alloc")]
154-
fn choose_multiple<R>(self, rng: &mut R, amount: usize) -> Vec<Self::Item>
183+
fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
155184
where R: Rng + ?Sized
156185
{
157-
// Note: I think this must use unsafe to create an uninitialised buffer, then restrict length
158-
unimplemented!()
186+
let mut reservoir = Vec::with_capacity(amount);
187+
reservoir.extend(self.by_ref().take(amount));
188+
189+
// Continue unless the iterator was exhausted
190+
//
191+
// note: this prevents iterators that "restart" from causing problems.
192+
// If the iterator stops once, then so do we.
193+
if reservoir.len() == amount {
194+
for (i, elem) in self.enumerate() {
195+
let k = rng.gen_range(0, i + 1 + amount);
196+
if let Some(spot) = reservoir.get_mut(k) {
197+
*spot = elem;
198+
}
199+
}
200+
} else {
201+
// Don't hang onto extra memory. There is a corner case where
202+
// `amount` was much greater than `self.len()`.
203+
reservoir.shrink_to_fit();
204+
}
205+
reservoir
159206
}
160207
}
161208

@@ -185,10 +232,15 @@ impl<T> SliceExt for [T] {
185232
}
186233

187234
#[cfg(feature = "alloc")]
188-
fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item>
235+
fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
189236
where R: Rng + ?Sized
190237
{
191-
unimplemented!()
238+
let amount = ::core::cmp::min(amount, self.len());
239+
SliceChooseIter {
240+
slice: self,
241+
_phantom: Default::default(),
242+
indices: sample_indices(rng, self.len(), amount).into_iter(),
243+
}
192244
}
193245

194246
fn shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized
@@ -209,57 +261,55 @@ impl<T> SliceExt for [T] {
209261
}
210262
}
211263

264+
impl<I> IteratorExt for I where I: Iterator + Sized {}
265+
266+
267+
/// Iterator over multiple choices, as returned by [`SliceExt::choose_multiple`](
268+
/// trait.SliceExt.html#method.choose_multiple).
269+
#[cfg(feature = "alloc")]
270+
#[derive(Debug)]
271+
pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> {
272+
slice: &'a S,
273+
_phantom: ::core::marker::PhantomData<T>,
274+
indices: vec::IntoIter<usize>,
275+
}
276+
277+
#[cfg(feature = "alloc")]
278+
impl<'a, S: Index<usize, Output = T> + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> {
279+
type Item = &'a T;
280+
281+
fn next(&mut self) -> Option<Self::Item> {
282+
self.indices.next().map(|i| &(*self.slice)[i])
283+
}
284+
285+
fn size_hint(&self) -> (usize, Option<usize>) {
286+
(self.indices.len(), Some(self.indices.len()))
287+
}
288+
}
289+
290+
212291
// ———
213-
// TODO: remove below methods once implemented above
214292
// TODO: also revise signature of `sample_indices`?
215293
// ———
216294

217295
/// Randomly sample `amount` elements from a finite iterator.
218296
///
219-
/// The following can be returned:
220-
///
221-
/// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
222-
/// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
223-
/// length of `iterable` was less than `amount`. This is considered an error since exactly
224-
/// `amount` elements is typically expected.
225-
///
226-
/// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
227-
///
228-
/// # Example
229-
///
230-
/// ```
231-
/// use rand::{thread_rng, seq};
232-
///
233-
/// let mut rng = thread_rng();
234-
/// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
235-
/// println!("{:?}", sample);
236-
/// ```
297+
/// Deprecated: use [`IteratorExt::choose_multiple`] instead.
298+
///
299+
/// [`IteratorExt::choose_multiple`]: trait.IteratorExt.html#method.choose_multiple
237300
#[cfg(feature = "alloc")]
301+
#[deprecated(since="0.6.0", note="use IteratorExt::choose_multiple instead")]
238302
pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
239303
where I: IntoIterator<Item=T>,
240304
R: Rng + ?Sized,
241305
{
242-
let mut iter = iterable.into_iter();
243-
let mut reservoir = Vec::with_capacity(amount);
244-
reservoir.extend(iter.by_ref().take(amount));
245-
246-
// Continue unless the iterator was exhausted
247-
//
248-
// note: this prevents iterators that "restart" from causing problems.
249-
// If the iterator stops once, then so do we.
250-
if reservoir.len() == amount {
251-
for (i, elem) in iter.enumerate() {
252-
let k = rng.gen_range(0, i + 1 + amount);
253-
if let Some(spot) = reservoir.get_mut(k) {
254-
*spot = elem;
255-
}
256-
}
257-
Ok(reservoir)
306+
use seq::IteratorExt;
307+
let iter = iterable.into_iter();
308+
let result = iter.choose_multiple(rng, amount);
309+
if result.len() == amount {
310+
Ok(result)
258311
} else {
259-
// Don't hang onto extra memory. There is a corner case where
260-
// `amount` was much greater than `len(iterable)`.
261-
reservoir.shrink_to_fit();
262-
Err(reservoir)
312+
Err(result)
263313
}
264314
}
265315

@@ -426,6 +476,7 @@ fn sample_indices_cache<R>(
426476
#[cfg(test)]
427477
mod test {
428478
use super::*;
479+
use super::IteratorExt;
429480
#[cfg(feature = "alloc")]
430481
use {XorShiftRng, Rng, SeedableRng};
431482
#[cfg(all(feature="alloc", not(feature="std")))]
@@ -468,8 +519,8 @@ mod test {
468519

469520
let mut r = ::test::rng(401);
470521
let vals = (min_val..max_val).collect::<Vec<i32>>();
471-
let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
472-
let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();
522+
let small_sample = vals.iter().choose_multiple(&mut r, 5);
523+
let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);
473524

474525
assert_eq!(small_sample.len(), 5);
475526
assert_eq!(large_sample.len(), vals.len());

0 commit comments

Comments
 (0)