Skip to content

Commit 2a786a8

Browse files
committed
Implement choose/choose_mut/choose_from_iterator on both Rng and on slice/Iterator
Implement Rng.choose_weighted() and Rng.choose_weighted_with_total()
1 parent ec3d7ef commit 2a786a8

File tree

2 files changed

+396
-6
lines changed

2 files changed

+396
-6
lines changed

src/lib.rs

+295-6
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ pub mod prelude;
269269
pub mod prng;
270270
pub mod rngs;
271271
#[cfg(feature = "alloc")] pub mod seq;
272+
#[cfg(feature = "alloc")] pub use seq::{SliceRandom, IteratorRandom};
272273

273274
////////////////////////////////////////////////////////////////////////////////
274275
// Compatibility re-exports. Documentation is hidden; will be removed eventually.
@@ -595,6 +596,161 @@ pub trait Rng: RngCore {
595596
}
596597
}
597598

599+
/// Returns one random element of the `Iterator`, or `None` if the
600+
/// `Iterator` returns no items. If you have a slice, it's significantly
601+
/// faster to call the [`choose`] or [`choose_mut`] functions using the
602+
/// slice instead. However it expected to be faster than dumping the
603+
/// Iterator into a slice and then calling [`choose`]/[`choose_mut`] on
604+
/// the slice.
605+
///
606+
/// # Example
607+
///
608+
/// ```
609+
/// use rand::{thread_rng, Rng};
610+
///
611+
/// let choices = std::iter::repeat(0)
612+
/// .scan((1, 1), |state, _| { let (a, b) = *state; *state = (b, a+b); Some(a) })
613+
/// .take(40);
614+
/// let mut rng = thread_rng();
615+
/// // Randomly choose one of the first 40 fibonacci numbers
616+
/// println!("{}", rng.choose_from_iterator(choices).unwrap());
617+
/// assert_eq!(rng.choose_from_iterator(std::iter::empty::<i32>()), None);
618+
/// ```
619+
/// [`choose`]: trait.Rng.html#method.choose
620+
/// [`choose_mut`]: trait.Rng.html#method.choose_mut
621+
fn choose_from_iterator<I: Iterator>(&mut self, mut iterable: I) -> Option<I::Item> {
622+
let mut val = iterable.next();
623+
if val.is_none() {
624+
return val;
625+
}
626+
627+
for (i, elem) in iterable.enumerate() {
628+
if self.gen_range(0, i + 2) == 0 {
629+
val = Some(elem);
630+
}
631+
}
632+
val
633+
}
634+
635+
/// Return a random element from `items` where. The chance of a given item
636+
/// being picked, is proportional to the corresponding value in `weights`.
637+
/// `weights` and `items` must return exactly the same number of values.
638+
///
639+
/// All values returned by `weights` must be `>= 0`.
640+
///
641+
/// This function iterates over `weights` twice. Once to get the total
642+
/// weight, and once while choosing the random value. If you know the total
643+
/// weight, or plan to call this function multiple times, you should
644+
/// consider using [`choose_weighted_with_total`] instead.
645+
///
646+
/// Return `None` if `items` and `weights` is empty.
647+
///
648+
/// # Example
649+
///
650+
/// ```
651+
/// use rand::{thread_rng, Rng};
652+
///
653+
/// let choices = ['a', 'b', 'c'];
654+
/// let weights = [2, 1, 1];
655+
/// let mut rng = thread_rng();
656+
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
657+
/// println!("{}", rng.choose_weighted(choices.iter(), weights.iter().cloned()).unwrap());
658+
/// ```
659+
/// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
660+
fn choose_weighted<IterItems, IterWeights>(&mut self,
661+
items: IterItems,
662+
weights: IterWeights) -> Option<IterItems::Item>
663+
where IterItems: Iterator,
664+
IterWeights: Iterator+Clone,
665+
IterWeights::Item: SampleUniform +
666+
Default +
667+
core::ops::Add<IterWeights::Item, Output=IterWeights::Item> +
668+
core::cmp::PartialOrd<IterWeights::Item> +
669+
Clone { // Clone is only needed for debug assertions
670+
let total_weight: IterWeights::Item =
671+
weights.clone().fold(Default::default(), |acc, w| {
672+
assert!(w >= Default::default(), "Weight must be larger than zero");
673+
acc + w
674+
});
675+
self.choose_weighted_with_total(items, weights, total_weight)
676+
}
677+
678+
/// Return a random element from `items` where. The chance of a given item
679+
/// being picked, is proportional to the corresponding value in `weights`.
680+
/// `weights` and `items` must return exactly the same number of values.
681+
///
682+
/// All values returned by `weights` must be `>= 0`.
683+
///
684+
/// `total_weight` must be exactly the sum of all values returned by
685+
/// `weights`. Builds with debug_assertions turned on will assert that this
686+
/// equality holds. Simply storing the result of `weights.sum()` and using
687+
/// that as `total_weight` should work.
688+
///
689+
/// Return `None` if `items` and `weights` is empty.
690+
///
691+
/// # Example
692+
///
693+
/// ```
694+
/// use rand::{thread_rng, Rng};
695+
///
696+
/// let choices = ['a', 'b', 'c'];
697+
/// let weights = [2, 1, 1];
698+
/// let mut rng = thread_rng();
699+
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
700+
/// println!("{}", rng.choose_weighted_with_total(choices.iter(), weights.iter().cloned(), 4).unwrap());
701+
/// ```
702+
/// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
703+
fn choose_weighted_with_total<IterItems, IterWeights>(&mut self,
704+
mut items: IterItems,
705+
mut weights: IterWeights,
706+
total_weight: IterWeights::Item) -> Option<IterItems::Item>
707+
where IterItems: Iterator,
708+
IterWeights: Iterator,
709+
IterWeights::Item: SampleUniform +
710+
Default +
711+
core::ops::Add<IterWeights::Item, Output=IterWeights::Item> +
712+
core::cmp::PartialOrd<IterWeights::Item> +
713+
Clone { // Clone is only needed for debug assertions
714+
715+
if total_weight == Default::default() {
716+
debug_assert!(items.next().is_none());
717+
return None;
718+
}
719+
720+
// Only used when debug_assertions are turned on
721+
let mut debug_result = None;
722+
let debug_total_weight = if cfg!(debug_assertions) { Some(total_weight.clone()) } else { None };
723+
724+
let chosen_weight = self.gen_range(Default::default(), total_weight);
725+
let mut cumulative_weight: IterWeights::Item = Default::default();
726+
727+
for item in items {
728+
let weight_opt = weights.next();
729+
assert!(weight_opt.is_some(), "`weights` returned fewer items than `items` did");
730+
let weight = weight_opt.unwrap();
731+
assert!(weight >= Default::default(), "Weight must be larger than zero");
732+
733+
cumulative_weight = cumulative_weight + weight;
734+
735+
if cumulative_weight > chosen_weight {
736+
if !cfg!(debug_assertions) {
737+
return Some(item);
738+
}
739+
if debug_result.is_none() {
740+
debug_result = Some(item);
741+
}
742+
}
743+
}
744+
745+
assert!(weights.next().is_none(), "`weights` returned more items than `items` did");
746+
debug_assert!(debug_total_weight.unwrap() == cumulative_weight);
747+
if cfg!(debug_assertions) && debug_result.is_some() {
748+
return debug_result;
749+
}
750+
751+
panic!("total_weight did not match up with sum of weights");
752+
}
753+
598754
/// Shuffle a mutable slice in place.
599755
///
600756
/// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle](
@@ -846,6 +1002,7 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
8461002
#[cfg(test)]
8471003
mod test {
8481004
use rngs::mock::StepRng;
1005+
#[cfg(feature="std")] use core::panic::catch_unwind;
8491006
use super::*;
8501007
#[cfg(all(not(feature="std"), feature="alloc"))] use alloc::boxed::Box;
8511008

@@ -976,15 +1133,50 @@ mod test {
9761133
#[test]
9771134
fn test_choose() {
9781135
let mut r = rng(107);
979-
assert_eq!(r.choose(&[1, 1, 1]).map(|&x|x), Some(1));
1136+
let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'];
1137+
let mut chosen = [0i32; 14];
1138+
for _ in 0..1000 {
1139+
let picked = *r.choose(&chars).unwrap();
1140+
chosen[(picked as usize) - ('a' as usize)] += 1;
1141+
}
1142+
for count in chosen.iter() {
1143+
let err = *count - (1000 / (chars.len() as i32));
1144+
assert!(-20 <= err && err <= 20);
1145+
}
9801146

981-
let v: &[isize] = &[];
982-
assert_eq!(r.choose(v), None);
1147+
chosen.iter_mut().for_each(|x| *x = 0);
1148+
for _ in 0..1000 {
1149+
*r.choose_mut(&mut chosen).unwrap() += 1;
1150+
}
1151+
for count in chosen.iter() {
1152+
let err = *count - (1000 / (chosen.len() as i32));
1153+
assert!(-20 <= err && err <= 20);
1154+
}
1155+
1156+
let mut v: [isize; 0] = [];
1157+
assert_eq!(r.choose(&v), None);
1158+
assert_eq!(r.choose_mut(&mut v), None);
9831159
}
9841160

9851161
#[test]
986-
fn test_shuffle() {
1162+
fn test_choose_from_iterator() {
9871163
let mut r = rng(108);
1164+
let mut chosen = [0i32; 9];
1165+
for _ in 0..1000 {
1166+
let picked = r.choose_from_iterator(0..9).unwrap();
1167+
chosen[picked] += 1;
1168+
}
1169+
for count in chosen.iter() {
1170+
let err = *count - 1000 / 9;
1171+
assert!(-25 <= err && err <= 25);
1172+
}
1173+
1174+
assert_eq!(r.choose_from_iterator(0..0), None);
1175+
}
1176+
1177+
#[test]
1178+
fn test_shuffle() {
1179+
let mut r = rng(109);
9881180
let empty: &mut [isize] = &mut [];
9891181
r.shuffle(empty);
9901182
let mut one = [1];
@@ -1005,7 +1197,7 @@ mod test {
10051197
#[test]
10061198
fn test_rng_trait_object() {
10071199
use distributions::{Distribution, Standard};
1008-
let mut rng = rng(109);
1200+
let mut rng = rng(110);
10091201
let mut r = &mut rng as &mut RngCore;
10101202
r.next_u32();
10111203
r.gen::<i32>();
@@ -1021,7 +1213,7 @@ mod test {
10211213
#[cfg(feature="alloc")]
10221214
fn test_rng_boxed_trait() {
10231215
use distributions::{Distribution, Standard};
1024-
let rng = rng(110);
1216+
let rng = rng(111);
10251217
let mut r = Box::new(rng) as Box<RngCore>;
10261218
r.next_u32();
10271219
r.gen::<i32>();
@@ -1049,6 +1241,7 @@ mod test {
10491241
}
10501242

10511243
#[test]
1244+
<<<<<<< HEAD
10521245
fn test_gen_ratio_average() {
10531246
const NUM: u32 = 3;
10541247
const DENOM: u32 = 10;
@@ -1063,5 +1256,101 @@ mod test {
10631256
}
10641257
let avg = (sum as f64) / (N as f64);
10651258
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
1259+
=======
1260+
fn test_choose_weighted() {
1261+
let mut r = rng(112);
1262+
let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'];
1263+
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
1264+
let total_weight = weights.iter().sum();
1265+
assert_eq!(chars.len(), weights.len());
1266+
1267+
let mut chosen = [0i32; 14];
1268+
for _ in 0..1000 {
1269+
let picked = *r.choose_weighted(chars.iter(),
1270+
weights.iter().cloned()).unwrap();
1271+
chosen[(picked as usize) - ('a' as usize)] += 1;
1272+
}
1273+
for (i, count) in chosen.iter().enumerate() {
1274+
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
1275+
assert!(-25 <= err && err <= 25);
1276+
}
1277+
1278+
// Mutable items
1279+
chosen.iter_mut().for_each(|x| *x = 0);
1280+
for _ in 0..1000 {
1281+
*r.choose_weighted(chosen.iter_mut(),
1282+
weights.iter().cloned()).unwrap() += 1;
1283+
}
1284+
for (i, count) in chosen.iter().enumerate() {
1285+
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
1286+
assert!(-25 <= err && err <= 25);
1287+
}
1288+
1289+
// choose_weighted_with_total
1290+
chosen.iter_mut().for_each(|x| *x = 0);
1291+
for _ in 0..1000 {
1292+
let picked = *r.choose_weighted_with_total(chars.iter(),
1293+
weights.iter().cloned(),
1294+
total_weight).unwrap();
1295+
chosen[(picked as usize) - ('a' as usize)] += 1;
1296+
}
1297+
for (i, count) in chosen.iter().enumerate() {
1298+
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
1299+
assert!(-25 <= err && err <= 25);
1300+
}
1301+
}
1302+
1303+
#[test]
1304+
#[cfg(all(feature="std",
1305+
not(target_arch = "wasm32"),
1306+
not(target_arch = "asmjs")))]
1307+
fn test_choose_weighted_assertions() {
1308+
fn inner_delta(delta: i32) {
1309+
let items = vec![1, 2, 3];
1310+
let mut r = rng(113);
1311+
if cfg!(debug_assertions) || delta == 0 {
1312+
r.choose_weighted_with_total(items.iter(),
1313+
items.iter().cloned(),
1314+
6+delta);
1315+
} else {
1316+
loop {
1317+
r.choose_weighted_with_total(items.iter(),
1318+
items.iter().cloned(),
1319+
6+delta);
1320+
}
1321+
}
1322+
}
1323+
1324+
assert!(catch_unwind(|| inner_delta(0)).is_ok());
1325+
assert!(catch_unwind(|| inner_delta(1)).is_err());
1326+
assert!(catch_unwind(|| inner_delta(1000)).is_err());
1327+
if cfg!(debug_assertions) {
1328+
// The non-debug-assertions code can't detect too small total_weight
1329+
assert!(catch_unwind(|| inner_delta(-1)).is_err());
1330+
assert!(catch_unwind(|| inner_delta(-1000)).is_err());
1331+
}
1332+
1333+
fn inner_size(items: usize, weights: usize, with_total: bool) {
1334+
let mut r = rng(114);
1335+
if with_total {
1336+
r.choose_weighted_with_total(core::iter::repeat(1usize).take(items),
1337+
core::iter::repeat(1usize).take(weights),
1338+
weights);
1339+
} else {
1340+
r.choose_weighted(core::iter::repeat(1usize).take(items),
1341+
core::iter::repeat(1usize).take(weights));
1342+
}
1343+
}
1344+
1345+
assert!(catch_unwind(|| inner_size(2, 2, true)).is_ok());
1346+
assert!(catch_unwind(|| inner_size(2, 2, false)).is_ok());
1347+
assert!(catch_unwind(|| inner_size(2, 1, true)).is_err());
1348+
assert!(catch_unwind(|| inner_size(2, 1, false)).is_err());
1349+
if cfg!(debug_assertions) {
1350+
// The non-debug-assertions code can't detect too many weights
1351+
assert!(catch_unwind(|| inner_size(2, 3, true)).is_err());
1352+
assert!(catch_unwind(|| inner_size(2, 3, false)).is_err());
1353+
}
1354+
>>>>>>> Implement choose/choose_mut/choose_from_iterator on both Rng and on slice/Iterator
10661355
}
10671356
}

0 commit comments

Comments
 (0)