16
16
#[ cfg( feature="std" ) ] use std:: collections:: { HashSet } ;
17
17
#[ cfg( all( feature="alloc" , not( feature="std" ) ) ) ] use alloc:: collections:: BTreeSet ;
18
18
19
- #[ cfg( feature="alloc" ) ] use distributions:: { Distribution , Uniform } ;
19
+ #[ cfg( feature="alloc" ) ] use distributions:: { Distribution , Uniform , uniform :: SampleUniform } ;
20
20
use Rng ;
21
21
22
22
/// A vector of indices.
@@ -212,9 +212,7 @@ where R: Rng + ?Sized {
212
212
if ( length as f32 ) < C [ j] * ( amount as f32 ) {
213
213
sample_inplace ( rng, length, amount)
214
214
} else {
215
- // note: could have a specific u32 impl, but I'm lazy and
216
- // generics don't have usable conversions
217
- sample_rejection ( rng, length as usize , amount as usize )
215
+ sample_rejection ( rng, length, amount)
218
216
}
219
217
}
220
218
}
@@ -285,28 +283,44 @@ where R: Rng + ?Sized {
285
283
IndexVec :: from ( indices)
286
284
}
287
285
286
+ trait UInt : Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core:: hash:: Hash {
287
+ fn zero ( ) -> Self ;
288
+ fn as_usize ( self ) -> usize ;
289
+ }
290
+ impl UInt for u32 {
291
+ #[ inline] fn zero ( ) -> Self { 0 }
292
+ #[ inline] fn as_usize ( self ) -> usize { self as usize }
293
+ }
294
+ impl UInt for usize {
295
+ #[ inline] fn zero ( ) -> Self { 0 }
296
+ #[ inline] fn as_usize ( self ) -> usize { self }
297
+ }
298
+
288
299
/// Randomly sample exactly `amount` indices from `0..length`, using rejection
289
300
/// sampling.
290
301
///
291
302
/// Since `amount <<< length` there is a low chance of a random sample in
292
303
/// `0..length` being a duplicate. We test for duplicates and resample where
293
304
/// necessary. The algorithm is `O(amount)` time and memory.
294
- fn sample_rejection < R > ( rng : & mut R , length : usize , amount : usize ) -> IndexVec
295
- where R : Rng + ?Sized {
305
+ ///
306
+ /// This function is generic over X primarily so that results are value-stable
307
+ /// over 32-bit and 64-bit platforms.
308
+ fn sample_rejection < X : UInt , R > ( rng : & mut R , length : X , amount : X ) -> IndexVec
309
+ where R : Rng + ?Sized , IndexVec : From < Vec < X > > {
296
310
debug_assert ! ( amount < length) ;
297
- #[ cfg( feature="std" ) ] let mut cache = HashSet :: with_capacity ( amount) ;
311
+ #[ cfg( feature="std" ) ] let mut cache = HashSet :: with_capacity ( amount. as_usize ( ) ) ;
298
312
#[ cfg( not( feature="std" ) ) ] let mut cache = BTreeSet :: new ( ) ;
299
- let distr = Uniform :: new ( 0 , length) ;
300
- let mut indices = Vec :: with_capacity ( amount) ;
301
- for _ in 0 ..amount {
313
+ let distr = Uniform :: new ( X :: zero ( ) , length) ;
314
+ let mut indices = Vec :: with_capacity ( amount. as_usize ( ) ) ;
315
+ for _ in 0 ..amount. as_usize ( ) {
302
316
let mut pos = distr. sample ( rng) ;
303
317
while !cache. insert ( pos) {
304
318
pos = distr. sample ( rng) ;
305
319
}
306
320
indices. push ( pos) ;
307
321
}
308
322
309
- debug_assert_eq ! ( indices. len( ) , amount) ;
323
+ debug_assert_eq ! ( indices. len( ) , amount. as_usize ( ) ) ;
310
324
IndexVec :: from ( indices)
311
325
}
312
326
@@ -322,14 +336,14 @@ mod test {
322
336
assert_eq ! ( sample_inplace( & mut r, 1 , 0 ) . len( ) , 0 ) ;
323
337
assert_eq ! ( sample_inplace( & mut r, 1 , 1 ) . into_vec( ) , vec![ 0 ] ) ;
324
338
325
- assert_eq ! ( sample_rejection( & mut r, 1 , 0 ) . len( ) , 0 ) ;
339
+ assert_eq ! ( sample_rejection( & mut r, 1u32 , 0 ) . len( ) , 0 ) ;
326
340
327
341
assert_eq ! ( sample_floyd( & mut r, 0 , 0 ) . len( ) , 0 ) ;
328
342
assert_eq ! ( sample_floyd( & mut r, 1 , 0 ) . len( ) , 0 ) ;
329
343
assert_eq ! ( sample_floyd( & mut r, 1 , 1 ) . into_vec( ) , vec![ 0 ] ) ;
330
344
331
345
// These algorithms should be fast with big numbers. Test average.
332
- let sum: usize = sample_rejection ( & mut r, 1 << 25 , 10 )
346
+ let sum: usize = sample_rejection ( & mut r, 1 << 25 , 10u32 )
333
347
. into_iter ( ) . sum ( ) ;
334
348
assert ! ( 1 << 25 < sum && sum < ( 1 << 25 ) * 25 ) ;
335
349
@@ -368,7 +382,7 @@ mod test {
368
382
// A large length and larger amount should use cache
369
383
let ( length, amount) : ( usize , usize ) = ( 1 <<20 , 600 ) ;
370
384
let v1 = sample ( & mut seed_rng ( 422 ) , length, amount) ;
371
- let v2 = sample_rejection ( & mut seed_rng ( 422 ) , length, amount) ;
385
+ let v2 = sample_rejection ( & mut seed_rng ( 422 ) , length as u32 , amount as u32 ) ;
372
386
assert ! ( v1. iter( ) . all( |e| e < length) ) ;
373
387
assert_eq ! ( v1, v2) ;
374
388
}
0 commit comments