Skip to content

Don't require &mut ref in WeightedChoice::new #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 56 additions & 39 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,53 +114,53 @@ pub struct Weighted<T> {
/// }
/// ```
#[derive(Debug)]
pub struct WeightedChoice<'a, T:'a> {
items: &'a mut [Weighted<T>],
weight_range: Range<u32>
pub struct WeightedChoice<T: AsMut<[Weighted<U>]> + AsRef<[Weighted<U>]>, U: Clone> {
items: T,
weight_range: Range<u32>,
_phantom: marker::PhantomData<U>,
}

impl<'a, T: Clone> WeightedChoice<'a, T> {
/// Create a new `WeightedChoice`.
///
/// Panics if:
/// - `v` is empty
/// - the total weight is 0
/// - the total weight is larger than a `u32` can contain.
pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
// strictly speaking, this is subsumed by the total weight == 0 case
assert!(!items.is_empty(), "WeightedChoice::new called with no items");

impl<T: AsMut<[Weighted<U>]> + AsRef<[Weighted<U>]>, U: Clone> WeightedChoice<T, U> {
pub fn new(mut items_init: T) -> WeightedChoice<T, U> {
let mut running_total: u32 = 0;

// we convert the list from individual weights to cumulative
// weights so we can binary search. This *could* drop elements
// with weight == 0 as an optimisation.
for item in items.iter_mut() {
running_total = match running_total.checked_add(item.weight) {
Some(n) => n,
None => panic!("WeightedChoice::new called with a total weight \
larger than a u32 can contain")
};

item.weight = running_total;
{
let items = items_init.as_mut();
// strictly speaking, this is subsumed by the total weight == 0 case
assert!(!items.is_empty(), "WeightedChoice::new called with no items");

// we convert the list from individual weights to cumulative
// weights so we can binary search. This *could* drop elements
// with weight == 0 as an optimisation.
for item in items.iter_mut() {
running_total = match running_total.checked_add(item.weight) {
Some(n) => n,
None => panic!("WeightedChoice::new called with a total weight \
larger than a u32 can contain")
};

item.weight = running_total;
}
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
}
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");

WeightedChoice {
items: items,
items: items_init,
// we're likely to be generating numbers in this range
// relatively often, so might as well cache it
weight_range: Range::new(0, running_total)
weight_range: Range::new(0, running_total),
_phantom: marker::PhantomData,
}
}
}

impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }


impl<T: AsMut<[Weighted<U>]> + AsRef<[Weighted<U>]>, U: Clone> Sample<U> for WeightedChoice<T, U> {
fn sample<R: Rng>(&mut self, rng: &mut R) -> U { self.ind_sample(rng) }
}

impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
impl<T: AsMut<[Weighted<U>]> + AsRef<[Weighted<U>]>, U: Clone> IndependentSample<U> for WeightedChoice<T, U> {
fn ind_sample<R: Rng>(&self, rng: &mut R) -> U {
// we want to find the first element that has cumulative
// weight > sample_weight, which we do by binary since the
// cumulative weights of self.items are sorted.
Expand All @@ -169,12 +169,12 @@ impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
let sample_weight = self.weight_range.ind_sample(rng);

// short circuit when it's the first item
if sample_weight < self.items[0].weight {
return self.items[0].item.clone();
if sample_weight < self.items.as_ref()[0].weight {
return self.items.as_ref()[0].item.clone();
}

let mut idx = 0;
let mut modifier = self.items.len();
let mut modifier = self.items.as_ref().len();

// now we know that every possibility has an element to the
// left, so we can just search for the last element that has
Expand All @@ -185,7 +185,7 @@ impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
// one is exactly the total weight.)
while modifier > 1 {
let i = idx + modifier / 2;
if self.items[i].weight <= sample_weight {
if self.items.as_ref()[i].weight <= sample_weight {
// we're small, so look to the right, but allow this
// exact element still.
idx = i;
Expand All @@ -198,7 +198,7 @@ impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
}
modifier /= 2;
}
return self.items[idx + 1].item.clone();
return self.items.as_ref()[idx + 1].item.clone();
}
}

Expand Down Expand Up @@ -382,13 +382,16 @@ mod tests {

#[test] #[should_panic]
fn test_weighted_choice_no_items() {
WeightedChoice::<isize>::new(&mut []);
let xs: [Weighted<isize>; 0] = [];
WeightedChoice::new(xs);
}

#[test] #[should_panic]
fn test_weighted_choice_zero_weight() {
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
Weighted { weight: 0, item: 1}]);
}

#[test] #[should_panic]
fn test_weighted_choice_weight_overflows() {
let x = ::std::u32::MAX / 2; // x + x + 2 is the overflow
Expand All @@ -397,4 +400,18 @@ mod tests {
Weighted { weight: x, item: 2 },
Weighted { weight: 1, item: 3 }]);
}

#[test]
fn test_construct_weighted_choice() {
/// We should be able to construct a WeightedChoice in a function and
/// return it, for great ergonomics.
fn mk() -> WeightedChoice<Vec<Weighted<usize>>, usize> {
let inputs = vec![
Weighted { weight: 1, item: 0 },
Weighted { weight: 1, item: 1 },
];
WeightedChoice::new(inputs)
}
let _ = mk();
}
}