@@ -269,6 +269,7 @@ pub mod prelude;
269
269
pub mod prng;
270
270
pub mod rngs;
271
271
#[ cfg( feature = "alloc" ) ] pub mod seq;
272
+ #[ cfg( feature = "alloc" ) ] pub use seq:: { SliceRandom , IteratorRandom } ;
272
273
273
274
////////////////////////////////////////////////////////////////////////////////
274
275
// Compatibility re-exports. Documentation is hidden; will be removed eventually.
@@ -595,6 +596,161 @@ pub trait Rng: RngCore {
595
596
}
596
597
}
597
598
599
+ /// Returns one random element of the `Iterator`, or `None` if the
600
+ /// `Iterator` returns no items. If you have a slice, it's significantly
601
+ /// faster to call the [`choose`] or [`choose_mut`] functions using the
602
+ /// slice instead. However it expected to be faster than dumping the
603
+ /// Iterator into a slice and then calling [`choose`]/[`choose_mut`] on
604
+ /// the slice.
605
+ ///
606
+ /// # Example
607
+ ///
608
+ /// ```
609
+ /// use rand::{thread_rng, Rng};
610
+ ///
611
+ /// let choices = std::iter::repeat(0)
612
+ /// .scan((1, 1), |state, _| { let (a, b) = *state; *state = (b, a+b); Some(a) })
613
+ /// .take(40);
614
+ /// let mut rng = thread_rng();
615
+ /// // Randomly choose one of the first 40 fibonacci numbers
616
+ /// println!("{}", rng.choose_from_iterator(choices).unwrap());
617
+ /// assert_eq!(rng.choose_from_iterator(std::iter::empty::<i32>()), None);
618
+ /// ```
619
+ /// [`choose`]: trait.Rng.html#method.choose
620
+ /// [`choose_mut`]: trait.Rng.html#method.choose_mut
621
+ fn choose_from_iterator < I : Iterator > ( & mut self , mut iterable : I ) -> Option < I :: Item > {
622
+ let mut val = iterable. next ( ) ;
623
+ if val. is_none ( ) {
624
+ return val;
625
+ }
626
+
627
+ for ( i, elem) in iterable. enumerate ( ) {
628
+ if self . gen_range ( 0 , i + 2 ) == 0 {
629
+ val = Some ( elem) ;
630
+ }
631
+ }
632
+ val
633
+ }
634
+
635
+ /// Return a random element from `items` where. The chance of a given item
636
+ /// being picked, is proportional to the corresponding value in `weights`.
637
+ /// `weights` and `items` must return exactly the same number of values.
638
+ ///
639
+ /// All values returned by `weights` must be `>= 0`.
640
+ ///
641
+ /// This function iterates over `weights` twice. Once to get the total
642
+ /// weight, and once while choosing the random value. If you know the total
643
+ /// weight, or plan to call this function multiple times, you should
644
+ /// consider using [`choose_weighted_with_total`] instead.
645
+ ///
646
+ /// Return `None` if `items` and `weights` is empty.
647
+ ///
648
+ /// # Example
649
+ ///
650
+ /// ```
651
+ /// use rand::{thread_rng, Rng};
652
+ ///
653
+ /// let choices = ['a', 'b', 'c'];
654
+ /// let weights = [2, 1, 1];
655
+ /// let mut rng = thread_rng();
656
+ /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
657
+ /// println!("{}", rng.choose_weighted(choices.iter(), weights.iter().cloned()).unwrap());
658
+ /// ```
659
+ /// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
660
+ fn choose_weighted < IterItems , IterWeights > ( & mut self ,
661
+ items : IterItems ,
662
+ weights : IterWeights ) -> Option < IterItems :: Item >
663
+ where IterItems : Iterator ,
664
+ IterWeights : Iterator +Clone ,
665
+ IterWeights :: Item : SampleUniform +
666
+ Default +
667
+ core:: ops:: Add < IterWeights :: Item , Output =IterWeights :: Item > +
668
+ core:: cmp:: PartialOrd < IterWeights :: Item > +
669
+ Clone { // Clone is only needed for debug assertions
670
+ let total_weight: IterWeights :: Item =
671
+ weights. clone ( ) . fold ( Default :: default ( ) , |acc, w| {
672
+ assert ! ( w >= Default :: default ( ) , "Weight must be larger than zero" ) ;
673
+ acc + w
674
+ } ) ;
675
+ self . choose_weighted_with_total ( items, weights, total_weight)
676
+ }
677
+
678
+ /// Return a random element from `items` where. The chance of a given item
679
+ /// being picked, is proportional to the corresponding value in `weights`.
680
+ /// `weights` and `items` must return exactly the same number of values.
681
+ ///
682
+ /// All values returned by `weights` must be `>= 0`.
683
+ ///
684
+ /// `total_weight` must be exactly the sum of all values returned by
685
+ /// `weights`. Builds with debug_assertions turned on will assert that this
686
+ /// equality holds. Simply storing the result of `weights.sum()` and using
687
+ /// that as `total_weight` should work.
688
+ ///
689
+ /// Return `None` if `items` and `weights` is empty.
690
+ ///
691
+ /// # Example
692
+ ///
693
+ /// ```
694
+ /// use rand::{thread_rng, Rng};
695
+ ///
696
+ /// let choices = ['a', 'b', 'c'];
697
+ /// let weights = [2, 1, 1];
698
+ /// let mut rng = thread_rng();
699
+ /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
700
+ /// println!("{}", rng.choose_weighted_with_total(choices.iter(), weights.iter().cloned(), 4).unwrap());
701
+ /// ```
702
+ /// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
703
+ fn choose_weighted_with_total < IterItems , IterWeights > ( & mut self ,
704
+ mut items : IterItems ,
705
+ mut weights : IterWeights ,
706
+ total_weight : IterWeights :: Item ) -> Option < IterItems :: Item >
707
+ where IterItems : Iterator ,
708
+ IterWeights : Iterator ,
709
+ IterWeights :: Item : SampleUniform +
710
+ Default +
711
+ core:: ops:: Add < IterWeights :: Item , Output =IterWeights :: Item > +
712
+ core:: cmp:: PartialOrd < IterWeights :: Item > +
713
+ Clone { // Clone is only needed for debug assertions
714
+
715
+ if total_weight == Default :: default ( ) {
716
+ debug_assert ! ( items. next( ) . is_none( ) ) ;
717
+ return None ;
718
+ }
719
+
720
+ // Only used when debug_assertions are turned on
721
+ let mut debug_result = None ;
722
+ let debug_total_weight = if cfg ! ( debug_assertions) { Some ( total_weight. clone ( ) ) } else { None } ;
723
+
724
+ let chosen_weight = self . gen_range ( Default :: default ( ) , total_weight) ;
725
+ let mut cumulative_weight: IterWeights :: Item = Default :: default ( ) ;
726
+
727
+ for item in items {
728
+ let weight_opt = weights. next ( ) ;
729
+ assert ! ( weight_opt. is_some( ) , "`weights` returned fewer items than `items` did" ) ;
730
+ let weight = weight_opt. unwrap ( ) ;
731
+ assert ! ( weight >= Default :: default ( ) , "Weight must be larger than zero" ) ;
732
+
733
+ cumulative_weight = cumulative_weight + weight;
734
+
735
+ if cumulative_weight > chosen_weight {
736
+ if !cfg ! ( debug_assertions) {
737
+ return Some ( item) ;
738
+ }
739
+ if debug_result. is_none ( ) {
740
+ debug_result = Some ( item) ;
741
+ }
742
+ }
743
+ }
744
+
745
+ assert ! ( weights. next( ) . is_none( ) , "`weights` returned more items than `items` did" ) ;
746
+ debug_assert ! ( debug_total_weight. unwrap( ) == cumulative_weight) ;
747
+ if cfg ! ( debug_assertions) && debug_result. is_some ( ) {
748
+ return debug_result;
749
+ }
750
+
751
+ panic ! ( "total_weight did not match up with sum of weights" ) ;
752
+ }
753
+
598
754
/// Shuffle a mutable slice in place.
599
755
///
600
756
/// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle](
@@ -846,6 +1002,7 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
846
1002
#[ cfg( test) ]
847
1003
mod test {
848
1004
use rngs:: mock:: StepRng ;
1005
+ #[ cfg( feature="std" ) ] use core:: panic:: catch_unwind;
849
1006
use super :: * ;
850
1007
#[ cfg( all( not( feature="std" ) , feature="alloc" ) ) ] use alloc:: boxed:: Box ;
851
1008
@@ -976,15 +1133,50 @@ mod test {
976
1133
#[ test]
977
1134
fn test_choose ( ) {
978
1135
let mut r = rng ( 107 ) ;
979
- assert_eq ! ( r. choose( & [ 1 , 1 , 1 ] ) . map( |& x|x) , Some ( 1 ) ) ;
1136
+ let chars = [ 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' , 'h' , 'i' , 'j' , 'k' , 'l' , 'm' , 'n' ] ;
1137
+ let mut chosen = [ 0i32 ; 14 ] ;
1138
+ for _ in 0 ..1000 {
1139
+ let picked = * r. choose ( & chars) . unwrap ( ) ;
1140
+ chosen[ ( picked as usize ) - ( 'a' as usize ) ] += 1 ;
1141
+ }
1142
+ for count in chosen. iter ( ) {
1143
+ let err = * count - ( 1000 / ( chars. len ( ) as i32 ) ) ;
1144
+ assert ! ( -20 <= err && err <= 20 ) ;
1145
+ }
980
1146
981
- let v: & [ isize ] = & [ ] ;
982
- assert_eq ! ( r. choose( v) , None ) ;
1147
+ chosen. iter_mut ( ) . for_each ( |x| * x = 0 ) ;
1148
+ for _ in 0 ..1000 {
1149
+ * r. choose_mut ( & mut chosen) . unwrap ( ) += 1 ;
1150
+ }
1151
+ for count in chosen. iter ( ) {
1152
+ let err = * count - ( 1000 / ( chosen. len ( ) as i32 ) ) ;
1153
+ assert ! ( -20 <= err && err <= 20 ) ;
1154
+ }
1155
+
1156
+ let mut v: [ isize ; 0 ] = [ ] ;
1157
+ assert_eq ! ( r. choose( & v) , None ) ;
1158
+ assert_eq ! ( r. choose_mut( & mut v) , None ) ;
983
1159
}
984
1160
985
1161
#[ test]
986
- fn test_shuffle ( ) {
1162
+ fn test_choose_from_iterator ( ) {
987
1163
let mut r = rng ( 108 ) ;
1164
+ let mut chosen = [ 0i32 ; 9 ] ;
1165
+ for _ in 0 ..1000 {
1166
+ let picked = r. choose_from_iterator ( 0 ..9 ) . unwrap ( ) ;
1167
+ chosen[ picked] += 1 ;
1168
+ }
1169
+ for count in chosen. iter ( ) {
1170
+ let err = * count - 1000 / 9 ;
1171
+ assert ! ( -25 <= err && err <= 25 ) ;
1172
+ }
1173
+
1174
+ assert_eq ! ( r. choose_from_iterator( 0 ..0 ) , None ) ;
1175
+ }
1176
+
1177
+ #[ test]
1178
+ fn test_shuffle ( ) {
1179
+ let mut r = rng ( 109 ) ;
988
1180
let empty: & mut [ isize ] = & mut [ ] ;
989
1181
r. shuffle ( empty) ;
990
1182
let mut one = [ 1 ] ;
@@ -1005,7 +1197,7 @@ mod test {
1005
1197
#[ test]
1006
1198
fn test_rng_trait_object ( ) {
1007
1199
use distributions:: { Distribution , Standard } ;
1008
- let mut rng = rng ( 109 ) ;
1200
+ let mut rng = rng ( 110 ) ;
1009
1201
let mut r = & mut rng as & mut RngCore ;
1010
1202
r. next_u32 ( ) ;
1011
1203
r. gen :: < i32 > ( ) ;
@@ -1021,7 +1213,7 @@ mod test {
1021
1213
#[ cfg( feature="alloc" ) ]
1022
1214
fn test_rng_boxed_trait ( ) {
1023
1215
use distributions:: { Distribution , Standard } ;
1024
- let rng = rng ( 110 ) ;
1216
+ let rng = rng ( 111 ) ;
1025
1217
let mut r = Box :: new ( rng) as Box < RngCore > ;
1026
1218
r. next_u32 ( ) ;
1027
1219
r. gen :: < i32 > ( ) ;
@@ -1049,6 +1241,7 @@ mod test {
1049
1241
}
1050
1242
1051
1243
#[ test]
1244
+ <<<<<<< HEAD
1052
1245
fn test_gen_ratio_average( ) {
1053
1246
const NUM : u32 = 3 ;
1054
1247
const DENOM : u32 = 10 ;
@@ -1063,5 +1256,101 @@ mod test {
1063
1256
}
1064
1257
let avg = ( sum as f64 ) / ( N as f64 ) ;
1065
1258
assert ! ( ( avg - ( NUM as f64 ) /( DENOM as f64 ) ) . abs( ) < 1e-3 ) ;
1259
+ =======
1260
+ fn test_choose_weighted ( ) {
1261
+ let mut r = rng ( 112 ) ;
1262
+ let chars = [ 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' , 'h' , 'i' , 'j' , 'k' , 'l' , 'm' , 'n' ] ;
1263
+ let weights = [ 1u32 , 2 , 3 , 0 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ;
1264
+ let total_weight = weights. iter ( ) . sum ( ) ;
1265
+ assert_eq ! ( chars. len( ) , weights. len( ) ) ;
1266
+
1267
+ let mut chosen = [ 0i32 ; 14 ] ;
1268
+ for _ in 0 ..1000 {
1269
+ let picked = * r. choose_weighted ( chars. iter ( ) ,
1270
+ weights. iter ( ) . cloned ( ) ) . unwrap ( ) ;
1271
+ chosen[ ( picked as usize ) - ( 'a' as usize ) ] += 1 ;
1272
+ }
1273
+ for ( i, count) in chosen. iter ( ) . enumerate ( ) {
1274
+ let err = * count - ( ( weights[ i] * 1000 / total_weight) as i32 ) ;
1275
+ assert ! ( -25 <= err && err <= 25 ) ;
1276
+ }
1277
+
1278
+ // Mutable items
1279
+ chosen. iter_mut ( ) . for_each ( |x| * x = 0 ) ;
1280
+ for _ in 0 ..1000 {
1281
+ * r. choose_weighted ( chosen. iter_mut ( ) ,
1282
+ weights. iter ( ) . cloned ( ) ) . unwrap ( ) += 1 ;
1283
+ }
1284
+ for ( i, count) in chosen. iter ( ) . enumerate ( ) {
1285
+ let err = * count - ( ( weights[ i] * 1000 / total_weight) as i32 ) ;
1286
+ assert ! ( -25 <= err && err <= 25 ) ;
1287
+ }
1288
+
1289
+ // choose_weighted_with_total
1290
+ chosen. iter_mut ( ) . for_each ( |x| * x = 0 ) ;
1291
+ for _ in 0 ..1000 {
1292
+ let picked = * r. choose_weighted_with_total ( chars. iter ( ) ,
1293
+ weights. iter ( ) . cloned ( ) ,
1294
+ total_weight) . unwrap ( ) ;
1295
+ chosen[ ( picked as usize ) - ( 'a' as usize ) ] += 1 ;
1296
+ }
1297
+ for ( i, count) in chosen. iter ( ) . enumerate ( ) {
1298
+ let err = * count - ( ( weights[ i] * 1000 / total_weight) as i32 ) ;
1299
+ assert ! ( -25 <= err && err <= 25 ) ;
1300
+ }
1301
+ }
1302
+
1303
+ #[ test]
1304
+ #[ cfg( all( feature="std" ,
1305
+ not( target_arch = "wasm32" ) ,
1306
+ not( target_arch = "asmjs" ) ) ) ]
1307
+ fn test_choose_weighted_assertions ( ) {
1308
+ fn inner_delta ( delta : i32 ) {
1309
+ let items = vec ! [ 1 , 2 , 3 ] ;
1310
+ let mut r = rng ( 113 ) ;
1311
+ if cfg ! ( debug_assertions) || delta == 0 {
1312
+ r. choose_weighted_with_total ( items. iter ( ) ,
1313
+ items. iter ( ) . cloned ( ) ,
1314
+ 6 +delta) ;
1315
+ } else {
1316
+ loop {
1317
+ r. choose_weighted_with_total ( items. iter ( ) ,
1318
+ items. iter ( ) . cloned ( ) ,
1319
+ 6 +delta) ;
1320
+ }
1321
+ }
1322
+ }
1323
+
1324
+ assert ! ( catch_unwind( || inner_delta( 0 ) ) . is_ok( ) ) ;
1325
+ assert ! ( catch_unwind( || inner_delta( 1 ) ) . is_err( ) ) ;
1326
+ assert ! ( catch_unwind( || inner_delta( 1000 ) ) . is_err( ) ) ;
1327
+ if cfg ! ( debug_assertions) {
1328
+ // The non-debug-assertions code can't detect too small total_weight
1329
+ assert ! ( catch_unwind( || inner_delta( -1 ) ) . is_err( ) ) ;
1330
+ assert ! ( catch_unwind( || inner_delta( -1000 ) ) . is_err( ) ) ;
1331
+ }
1332
+
1333
+ fn inner_size ( items : usize , weights : usize , with_total : bool ) {
1334
+ let mut r = rng ( 114 ) ;
1335
+ if with_total {
1336
+ r. choose_weighted_with_total ( core:: iter:: repeat ( 1usize ) . take ( items) ,
1337
+ core:: iter:: repeat ( 1usize ) . take ( weights) ,
1338
+ weights) ;
1339
+ } else {
1340
+ r. choose_weighted ( core:: iter:: repeat ( 1usize ) . take ( items) ,
1341
+ core:: iter:: repeat ( 1usize ) . take ( weights) ) ;
1342
+ }
1343
+ }
1344
+
1345
+ assert ! ( catch_unwind( || inner_size( 2 , 2 , true ) ) . is_ok( ) ) ;
1346
+ assert ! ( catch_unwind( || inner_size( 2 , 2 , false ) ) . is_ok( ) ) ;
1347
+ assert ! ( catch_unwind( || inner_size( 2 , 1 , true ) ) . is_err( ) ) ;
1348
+ assert ! ( catch_unwind( || inner_size( 2 , 1 , false ) ) . is_err( ) ) ;
1349
+ if cfg ! ( debug_assertions) {
1350
+ // The non-debug-assertions code can't detect too many weights
1351
+ assert ! ( catch_unwind( || inner_size( 2 , 3 , true ) ) . is_err( ) ) ;
1352
+ assert ! ( catch_unwind( || inner_size( 2 , 3 , false ) ) . is_err( ) ) ;
1353
+ }
1354
+ >>>>>>> Implement choose/choose_mut/choose_from_iterator on both Rng and on slice/Iterator
1066
1355
}
1067
1356
}
0 commit comments