Skip to content

Commit 4a375f6

Browse files
committed
Make seq::index::sample_rejection generic over uint index types
1 parent b613749 commit 4a375f6

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

src/seq/index.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#[cfg(feature="std")] use std::collections::{HashSet};
1717
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeSet;
1818

19-
#[cfg(feature="alloc")] use distributions::{Distribution, Uniform};
19+
#[cfg(feature="alloc")] use distributions::{Distribution, Uniform, uniform::SampleUniform};
2020
use Rng;
2121

2222
/// A vector of indices.
@@ -212,9 +212,7 @@ where R: Rng + ?Sized {
212212
if (length as f32) < C[j] * (amount as f32) {
213213
sample_inplace(rng, length, amount)
214214
} else {
215-
// note: could have a specific u32 impl, but I'm lazy and
216-
// generics don't have usable conversions
217-
sample_rejection(rng, length as usize, amount as usize)
215+
sample_rejection(rng, length, amount)
218216
}
219217
}
220218
}
@@ -285,28 +283,44 @@ where R: Rng + ?Sized {
285283
IndexVec::from(indices)
286284
}
287285

286+
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash {
287+
fn zero() -> Self;
288+
fn as_usize(self) -> usize;
289+
}
290+
impl UInt for u32 {
291+
#[inline] fn zero() -> Self { 0 }
292+
#[inline] fn as_usize(self) -> usize { self as usize }
293+
}
294+
impl UInt for usize {
295+
#[inline] fn zero() -> Self { 0 }
296+
#[inline] fn as_usize(self) -> usize { self }
297+
}
298+
288299
/// Randomly sample exactly `amount` indices from `0..length`, using rejection
289300
/// sampling.
290301
///
291302
/// Since `amount <<< length` there is a low chance of a random sample in
292303
/// `0..length` being a duplicate. We test for duplicates and resample where
293304
/// necessary. The algorithm is `O(amount)` time and memory.
294-
fn sample_rejection<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
295-
where R: Rng + ?Sized {
305+
///
306+
/// This function is generic over X primarily so that results are value-stable
307+
/// over 32-bit and 64-bit platforms.
308+
fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
309+
where R: Rng + ?Sized, IndexVec: From<Vec<X>> {
296310
debug_assert!(amount < length);
297-
#[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount);
311+
#[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount.as_usize());
298312
#[cfg(not(feature="std"))] let mut cache = BTreeSet::new();
299-
let distr = Uniform::new(0, length);
300-
let mut indices = Vec::with_capacity(amount);
301-
for _ in 0..amount {
313+
let distr = Uniform::new(X::zero(), length);
314+
let mut indices = Vec::with_capacity(amount.as_usize());
315+
for _ in 0..amount.as_usize() {
302316
let mut pos = distr.sample(rng);
303317
while !cache.insert(pos) {
304318
pos = distr.sample(rng);
305319
}
306320
indices.push(pos);
307321
}
308322

309-
debug_assert_eq!(indices.len(), amount);
323+
debug_assert_eq!(indices.len(), amount.as_usize());
310324
IndexVec::from(indices)
311325
}
312326

@@ -322,14 +336,14 @@ mod test {
322336
assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0);
323337
assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]);
324338

325-
assert_eq!(sample_rejection(&mut r, 1, 0).len(), 0);
339+
assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0);
326340

327341
assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0);
328342
assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0);
329343
assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]);
330344

331345
// These algorithms should be fast with big numbers. Test average.
332-
let sum: usize = sample_rejection(&mut r, 1 << 25, 10)
346+
let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32)
333347
.into_iter().sum();
334348
assert!(1 << 25 < sum && sum < (1 << 25) * 25);
335349

@@ -368,7 +382,7 @@ mod test {
368382
// A large length and larger amount should use cache
369383
let (length, amount): (usize, usize) = (1<<20, 600);
370384
let v1 = sample(&mut seed_rng(422), length, amount);
371-
let v2 = sample_rejection(&mut seed_rng(422), length, amount);
385+
let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
372386
assert!(v1.iter().all(|e| e < length));
373387
assert_eq!(v1, v2);
374388
}

0 commit comments

Comments
 (0)