Skip to content

Commit e9a27a8

Browse files
authored
Trap weighted index overflow (#1353)
* WeightedIndex: add test overflow (expected to panic) * WeightedIndex::new: trap overflow in release builds only * Introduce trait Weight * Update regarding nightly SIMD changes
1 parent 3c2e82f commit e9a27a8

File tree

7 files changed

+88
-35
lines changed

7 files changed

+88
-35
lines changed

src/distributions/float.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
1212
use crate::distributions::{Distribution, Standard};
1313
use crate::Rng;
1414
use core::mem;
15-
#[cfg(feature = "simd_support")] use core::simd::*;
15+
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
1616

1717
#[cfg(feature = "serde1")]
1818
use serde::{Serialize, Deserialize};

src/distributions/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ pub use self::slice::Slice;
126126
#[doc(inline)]
127127
pub use self::uniform::Uniform;
128128
#[cfg(feature = "alloc")]
129-
pub use self::weighted_index::{WeightedError, WeightedIndex};
129+
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};
130130

131131
#[allow(unused)]
132132
use crate::Rng;

src/distributions/other.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ use crate::Rng;
2222
use serde::{Serialize, Deserialize};
2323
use core::mem::{self, MaybeUninit};
2424
#[cfg(feature = "simd_support")]
25-
use core::simd::*;
25+
use core::simd::prelude::*;
26+
#[cfg(feature = "simd_support")]
27+
use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
2628

2729

2830
// ----- Sampling distributions -----
@@ -163,7 +165,7 @@ impl Distribution<bool> for Standard {
163165
/// Since most bits are unused you could also generate only as many bits as you need, i.e.:
164166
/// ```
165167
/// #![feature(portable_simd)]
166-
/// use std::simd::*;
168+
/// use std::simd::prelude::*;
167169
/// use rand::prelude::*;
168170
/// let mut rng = thread_rng();
169171
///

src/distributions/uniform.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ use crate::{Rng, RngCore};
119119
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
120120
use crate::distributions::utils::Float;
121121

122-
#[cfg(feature = "simd_support")] use core::simd::*;
122+
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
123+
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount};
123124

124125
/// Error type returned from [`Uniform::new`] and `new_inclusive`.
125126
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -1433,7 +1434,7 @@ mod tests {
14331434
(-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7),
14341435
];
14351436
for &(low_scalar, high_scalar) in v.iter() {
1436-
for lane in 0..<$ty>::LANES {
1437+
for lane in 0..<$ty>::LEN {
14371438
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
14381439
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
14391440
let my_uniform = Uniform::new(low, high).unwrap();
@@ -1565,7 +1566,7 @@ mod tests {
15651566
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY),
15661567
];
15671568
for &(low_scalar, high_scalar) in v.iter() {
1568-
for lane in 0..<$ty>::LANES {
1569+
for lane in 0..<$ty>::LEN {
15691570
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
15701571
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
15711572
assert!(catch_unwind(|| range(low, high)).is_err());

src/distributions/utils.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
//! Math helper functions
1010
11-
#[cfg(feature = "simd_support")] use core::simd::*;
11+
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
12+
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
1213

1314

1415
pub(crate) trait WideningMultiply<RHS = Self> {
@@ -245,7 +246,7 @@ pub(crate) trait Float: Sized {
245246

246247
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
247248
pub(crate) trait FloatAsSIMD: Sized {
248-
const LANES: usize = 1;
249+
const LEN: usize = 1;
249250
#[inline(always)]
250251
fn splat(scalar: Self) -> Self {
251252
scalar

src/distributions/weighted_index.rs

+70-5
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
9999
where
100100
I: IntoIterator,
101101
I::Item: SampleBorrow<X>,
102-
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
102+
X: Weight,
103103
{
104104
let mut iter = weights.into_iter();
105105
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
106106

107-
let zero = <X as Default>::default();
107+
let zero = X::ZERO;
108108
if !(total_weight >= zero) {
109109
return Err(WeightedError::InvalidWeight);
110110
}
@@ -117,7 +117,10 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
117117
return Err(WeightedError::InvalidWeight);
118118
}
119119
weights.push(total_weight.clone());
120-
total_weight += w.borrow();
120+
121+
if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
122+
return Err(WeightedError::Overflow);
123+
}
121124
}
122125

123126
if total_weight == zero {
@@ -236,6 +239,60 @@ where X: SampleUniform + PartialOrd
236239
}
237240
}
238241

242+
/// Bounds on a weight
243+
///
244+
/// See usage in [`WeightedIndex`].
245+
pub trait Weight: Clone {
246+
/// Representation of 0
247+
const ZERO: Self;
248+
249+
/// Checked addition
250+
///
251+
/// - `Result::Ok`: On success, `v` is added to `self`
252+
/// - `Result::Err`: Returns an error when `Self` cannot represent the
253+
/// result of `self + v` (i.e. overflow). The value of `self` should be
254+
/// discarded.
255+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
256+
}
257+
258+
macro_rules! impl_weight_int {
259+
($t:ty) => {
260+
impl Weight for $t {
261+
const ZERO: Self = 0;
262+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
263+
match self.checked_add(*v) {
264+
Some(sum) => {
265+
*self = sum;
266+
Ok(())
267+
}
268+
None => Err(()),
269+
}
270+
}
271+
}
272+
};
273+
($t:ty, $($tt:ty),*) => {
274+
impl_weight_int!($t);
275+
impl_weight_int!($($tt),*);
276+
}
277+
}
278+
impl_weight_int!(i8, i16, i32, i64, i128, isize);
279+
impl_weight_int!(u8, u16, u32, u64, u128, usize);
280+
281+
macro_rules! impl_weight_float {
282+
($t:ty) => {
283+
impl Weight for $t {
284+
const ZERO: Self = 0.0;
285+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
286+
// Floats have an explicit representation for overflow
287+
*self += *v;
288+
Ok(())
289+
}
290+
}
291+
}
292+
}
293+
impl_weight_float!(f32);
294+
impl_weight_float!(f64);
295+
239296
#[cfg(test)]
240297
mod test {
241298
use super::*;
@@ -388,12 +445,11 @@ mod test {
388445

389446
#[test]
390447
fn value_stability() {
391-
fn test_samples<X: SampleUniform + PartialOrd, I>(
448+
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
392449
weights: I, buf: &mut [usize], expected: &[usize],
393450
) where
394451
I: IntoIterator,
395452
I::Item: SampleBorrow<X>,
396-
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
397453
{
398454
assert_eq!(buf.len(), expected.len());
399455
let distr = WeightedIndex::new(weights).unwrap();
@@ -420,6 +476,11 @@ mod test {
420476
fn weighted_index_distributions_can_be_compared() {
421477
assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
422478
}
479+
480+
#[test]
481+
fn overflow() {
482+
assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow));
483+
}
423484
}
424485

425486
/// Error type returned from `WeightedIndex::new`.
@@ -438,6 +499,9 @@ pub enum WeightedError {
438499

439500
/// Too many weights are provided (length greater than `u32::MAX`)
440501
TooMany,
502+
503+
/// The sum of weights overflows
504+
Overflow,
441505
}
442506

443507
#[cfg(feature = "std")]
@@ -450,6 +514,7 @@ impl fmt::Display for WeightedError {
450514
WeightedError::InvalidWeight => "A weight is invalid in distribution",
451515
WeightedError::AllWeightsZero => "All weights are zero in distribution",
452516
WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
517+
WeightedError::Overflow => "The sum of weights overflowed",
453518
})
454519
}
455520
}

src/seq/mod.rs

+5-21
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use alloc::vec::Vec;
4040
#[cfg(feature = "alloc")]
4141
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
4242
#[cfg(feature = "alloc")]
43-
use crate::distributions::WeightedError;
43+
use crate::distributions::{Weight, WeightedError};
4444
use crate::Rng;
4545

4646
use self::coin_flipper::CoinFlipper;
@@ -170,11 +170,7 @@ pub trait SliceRandom {
170170
R: Rng + ?Sized,
171171
F: Fn(&Self::Item) -> B,
172172
B: SampleBorrow<X>,
173-
X: SampleUniform
174-
+ for<'a> ::core::ops::AddAssign<&'a X>
175-
+ ::core::cmp::PartialOrd<X>
176-
+ Clone
177-
+ Default;
173+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;
178174

179175
/// Biased sampling for one element (mut)
180176
///
@@ -203,11 +199,7 @@ pub trait SliceRandom {
203199
R: Rng + ?Sized,
204200
F: Fn(&Self::Item) -> B,
205201
B: SampleBorrow<X>,
206-
X: SampleUniform
207-
+ for<'a> ::core::ops::AddAssign<&'a X>
208-
+ ::core::cmp::PartialOrd<X>
209-
+ Clone
210-
+ Default;
202+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;
211203

212204
/// Biased sampling of `amount` distinct elements
213205
///
@@ -585,11 +577,7 @@ impl<T> SliceRandom for [T] {
585577
R: Rng + ?Sized,
586578
F: Fn(&Self::Item) -> B,
587579
B: SampleBorrow<X>,
588-
X: SampleUniform
589-
+ for<'a> ::core::ops::AddAssign<&'a X>
590-
+ ::core::cmp::PartialOrd<X>
591-
+ Clone
592-
+ Default,
580+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
593581
{
594582
use crate::distributions::{Distribution, WeightedIndex};
595583
let distr = WeightedIndex::new(self.iter().map(weight))?;
@@ -604,11 +592,7 @@ impl<T> SliceRandom for [T] {
604592
R: Rng + ?Sized,
605593
F: Fn(&Self::Item) -> B,
606594
B: SampleBorrow<X>,
607-
X: SampleUniform
608-
+ for<'a> ::core::ops::AddAssign<&'a X>
609-
+ ::core::cmp::PartialOrd<X>
610-
+ Clone
611-
+ Default,
595+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
612596
{
613597
use crate::distributions::{Distribution, WeightedIndex};
614598
let distr = WeightedIndex::new(self.iter().map(weight))?;

0 commit comments

Comments
 (0)