Skip to content

Commit c032204

Browse files
authored
Merge pull request #547 from sicking/weighterr
Add WeightedError
2 parents 40d8c39 + 3134986 commit c032204

File tree

3 files changed

+60
-24
lines changed

3 files changed

+60
-24
lines changed

src/distributions/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ use Rng;
177177
#[doc(inline)] pub use self::uniform::Uniform;
178178
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
179179
#[cfg(feature="alloc")]
180-
#[doc(inline)] pub use self::weighted::WeightedIndex;
180+
#[doc(inline)] pub use self::weighted::{WeightedIndex, WeightedError};
181181
#[cfg(feature="std")]
182182
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
183183
#[cfg(feature="std")]

src/distributions/weighted.rs

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use Rng;
1212
use distributions::Distribution;
1313
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
1414
use ::core::cmp::PartialOrd;
15-
use ::{Error, ErrorKind};
15+
use core::fmt;
1616

1717
// Note that this whole module is only imported if feature="alloc" is enabled.
1818
#[cfg(not(feature="std"))] use alloc::vec::Vec;
@@ -63,34 +63,34 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
6363
///
6464
/// [`Distribution`]: trait.Distribution.html
6565
/// [`Uniform<X>`]: struct.Uniform.html
66-
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
66+
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
6767
where I: IntoIterator,
6868
I::Item: SampleBorrow<X>,
6969
X: for<'a> ::core::ops::AddAssign<&'a X> +
7070
Clone +
7171
Default {
7272
let mut iter = weights.into_iter();
7373
let mut total_weight: X = iter.next()
74-
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
74+
.ok_or(WeightedError::NoItem)?
7575
.borrow()
7676
.clone();
7777

7878
let zero = <X as Default>::default();
7979
if total_weight < zero {
80-
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
80+
return Err(WeightedError::NegativeWeight);
8181
}
8282

8383
let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
8484
for w in iter {
8585
if *w.borrow() < zero {
86-
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
86+
return Err(WeightedError::NegativeWeight);
8787
}
8888
weights.push(total_weight.clone());
8989
total_weight += w.borrow();
9090
}
9191

9292
if total_weight == zero {
93-
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
93+
return Err(WeightedError::AllWeightsZero);
9494
}
9595
let distr = X::Sampler::new(zero, total_weight);
9696

@@ -161,10 +161,43 @@ mod test {
161161
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
162162
}
163163

164-
assert!(WeightedIndex::new(&[10][0..0]).is_err());
165-
assert!(WeightedIndex::new(&[0]).is_err());
166-
assert!(WeightedIndex::new(&[10, 20, -1, 30]).is_err());
167-
assert!(WeightedIndex::new(&[-10, 20, 1, 30]).is_err());
168-
assert!(WeightedIndex::new(&[-10]).is_err());
164+
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
165+
assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
166+
assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
167+
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
168+
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
169+
}
170+
}
171+
172+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173+
pub enum WeightedError {
174+
NoItem,
175+
NegativeWeight,
176+
AllWeightsZero,
177+
}
178+
179+
impl WeightedError {
180+
fn msg(&self) -> &str {
181+
match *self {
182+
WeightedError::NoItem => "No items found",
183+
WeightedError::NegativeWeight => "Item has negative weight",
184+
WeightedError::AllWeightsZero => "All items had weight zero",
185+
}
186+
}
187+
}
188+
189+
#[cfg(feature="std")]
190+
impl ::std::error::Error for WeightedError {
191+
fn description(&self) -> &str {
192+
self.msg()
193+
}
194+
fn cause(&self) -> Option<&::std::error::Error> {
195+
None
196+
}
197+
}
198+
199+
impl fmt::Display for WeightedError {
200+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
201+
write!(f, "{}", self.msg())
169202
}
170203
}

src/seq.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#[cfg(feature="std")] use std::collections::HashMap;
2222
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap;
2323

24+
#[cfg(feature = "alloc")] use distributions::WeightedError;
25+
2426
use super::Rng;
2527
#[cfg(feature="alloc")] use distributions::uniform::{SampleUniform, SampleBorrow};
2628

@@ -109,7 +111,7 @@ pub trait SliceRandom {
109111
/// ```
110112
/// [`choose`]: trait.SliceRandom.html#method.choose
111113
#[cfg(feature = "alloc")]
112-
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Option<&Self::Item>
114+
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError>
113115
where R: Rng + ?Sized,
114116
F: Fn(&Self::Item) -> B,
115117
B: SampleBorrow<X>,
@@ -129,7 +131,7 @@ pub trait SliceRandom {
129131
/// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut
130132
/// [`choose_weighted`]: trait.SliceRandom.html#method.choose_weighted
131133
#[cfg(feature = "alloc")]
132-
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item>
134+
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError>
133135
where R: Rng + ?Sized,
134136
F: Fn(&Self::Item) -> B,
135137
B: SampleBorrow<X>,
@@ -327,7 +329,7 @@ impl<T> SliceRandom for [T] {
327329
}
328330

329331
#[cfg(feature = "alloc")]
330-
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Option<&Self::Item>
332+
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError>
331333
where R: Rng + ?Sized,
332334
F: Fn(&Self::Item) -> B,
333335
B: SampleBorrow<X>,
@@ -337,12 +339,12 @@ impl<T> SliceRandom for [T] {
337339
Clone +
338340
Default {
339341
use distributions::{Distribution, WeightedIndex};
340-
WeightedIndex::new(self.iter().map(weight)).ok()
341-
.map(|distr| &self[distr.sample(rng)])
342+
let distr = WeightedIndex::new(self.iter().map(weight))?;
343+
Ok(&self[distr.sample(rng)])
342344
}
343345

344346
#[cfg(feature = "alloc")]
345-
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item>
347+
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError>
346348
where R: Rng + ?Sized,
347349
F: Fn(&Self::Item) -> B,
348350
B: SampleBorrow<X>,
@@ -352,9 +354,8 @@ impl<T> SliceRandom for [T] {
352354
Clone +
353355
Default {
354356
use distributions::{Distribution, WeightedIndex};
355-
WeightedIndex::new(self.iter().map(weight)).ok()
356-
.map(|distr| distr.sample(rng))
357-
.map(move |ix| &mut self[ix])
357+
let distr = WeightedIndex::new(self.iter().map(weight))?;
358+
Ok(&mut self[distr.sample(rng)])
358359
}
359360

360361
fn shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized
@@ -868,8 +869,10 @@ mod test {
868869

869870
// Check error cases
870871
let empty_slice = &mut [10][0..0];
871-
assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), None);
872-
assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), None);
873-
assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), None);
872+
assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), Err(WeightedError::NoItem));
873+
assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), Err(WeightedError::NoItem));
874+
assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), Err(WeightedError::AllWeightsZero));
875+
assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
876+
assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
874877
}
875878
}

0 commit comments

Comments
 (0)