@@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
99
99
where
100
100
I : IntoIterator ,
101
101
I :: Item : SampleBorrow < X > ,
102
- X : for < ' a > :: core :: ops :: AddAssign < & ' a X > + Clone + Default ,
102
+ X : Weight ,
103
103
{
104
104
let mut iter = weights. into_iter ( ) ;
105
105
let mut total_weight: X = iter. next ( ) . ok_or ( WeightedError :: NoItem ) ?. borrow ( ) . clone ( ) ;
106
106
107
- let zero = < X as Default > :: default ( ) ;
107
+ let zero = X :: ZERO ;
108
108
if !( total_weight >= zero) {
109
109
return Err ( WeightedError :: InvalidWeight ) ;
110
110
}
@@ -117,7 +117,10 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
117
117
return Err ( WeightedError :: InvalidWeight ) ;
118
118
}
119
119
weights. push ( total_weight. clone ( ) ) ;
120
- total_weight += w. borrow ( ) ;
120
+
121
+ if let Err ( ( ) ) = total_weight. checked_add_assign ( w. borrow ( ) ) {
122
+ return Err ( WeightedError :: Overflow ) ;
123
+ }
121
124
}
122
125
123
126
if total_weight == zero {
@@ -236,6 +239,60 @@ where X: SampleUniform + PartialOrd
236
239
}
237
240
}
238
241
242
+ /// Bounds on a weight
243
+ ///
244
+ /// See usage in [`WeightedIndex`].
245
+ pub trait Weight : Clone {
246
+ /// Representation of 0
247
+ const ZERO : Self ;
248
+
249
+ /// Checked addition
250
+ ///
251
+ /// - `Result::Ok`: On success, `v` is added to `self`
252
+ /// - `Result::Err`: Returns an error when `Self` cannot represent the
253
+ /// result of `self + v` (i.e. overflow). The value of `self` should be
254
+ /// discarded.
255
+ fn checked_add_assign ( & mut self , v : & Self ) -> Result < ( ) , ( ) > ;
256
+ }
257
+
258
+ macro_rules! impl_weight_int {
259
+ ( $t: ty) => {
260
+ impl Weight for $t {
261
+ const ZERO : Self = 0 ;
262
+ fn checked_add_assign( & mut self , v: & Self ) -> Result <( ) , ( ) > {
263
+ match self . checked_add( * v) {
264
+ Some ( sum) => {
265
+ * self = sum;
266
+ Ok ( ( ) )
267
+ }
268
+ None => Err ( ( ) ) ,
269
+ }
270
+ }
271
+ }
272
+ } ;
273
+ ( $t: ty, $( $tt: ty) ,* ) => {
274
+ impl_weight_int!( $t) ;
275
+ impl_weight_int!( $( $tt) ,* ) ;
276
+ }
277
+ }
278
+ impl_weight_int ! ( i8 , i16 , i32 , i64 , i128 , isize ) ;
279
+ impl_weight_int ! ( u8 , u16 , u32 , u64 , u128 , usize ) ;
280
+
281
+ macro_rules! impl_weight_float {
282
+ ( $t: ty) => {
283
+ impl Weight for $t {
284
+ const ZERO : Self = 0.0 ;
285
+ fn checked_add_assign( & mut self , v: & Self ) -> Result <( ) , ( ) > {
286
+ // Floats have an explicit representation for overflow
287
+ * self += * v;
288
+ Ok ( ( ) )
289
+ }
290
+ }
291
+ }
292
+ }
293
+ impl_weight_float ! ( f32 ) ;
294
+ impl_weight_float ! ( f64 ) ;
295
+
239
296
#[ cfg( test) ]
240
297
mod test {
241
298
use super :: * ;
@@ -388,12 +445,11 @@ mod test {
388
445
389
446
#[ test]
390
447
fn value_stability ( ) {
391
- fn test_samples < X : SampleUniform + PartialOrd , I > (
448
+ fn test_samples < X : Weight + SampleUniform + PartialOrd , I > (
392
449
weights : I , buf : & mut [ usize ] , expected : & [ usize ] ,
393
450
) where
394
451
I : IntoIterator ,
395
452
I :: Item : SampleBorrow < X > ,
396
- X : for < ' a > :: core:: ops:: AddAssign < & ' a X > + Clone + Default ,
397
453
{
398
454
assert_eq ! ( buf. len( ) , expected. len( ) ) ;
399
455
let distr = WeightedIndex :: new ( weights) . unwrap ( ) ;
@@ -420,6 +476,11 @@ mod test {
420
476
fn weighted_index_distributions_can_be_compared ( ) {
421
477
assert_eq ! ( WeightedIndex :: new( & [ 1 , 2 ] ) , WeightedIndex :: new( & [ 1 , 2 ] ) ) ;
422
478
}
479
+
480
+ #[ test]
481
+ fn overflow ( ) {
482
+ assert_eq ! ( WeightedIndex :: new( [ 2 , usize :: MAX ] ) , Err ( WeightedError :: Overflow ) ) ;
483
+ }
423
484
}
424
485
425
486
/// Error type returned from `WeightedIndex::new`.
@@ -438,6 +499,9 @@ pub enum WeightedError {
438
499
439
500
/// Too many weights are provided (length greater than `u32::MAX`)
440
501
TooMany ,
502
+
503
+ /// The sum of weights overflows
504
+ Overflow ,
441
505
}
442
506
443
507
#[ cfg( feature = "std" ) ]
@@ -450,6 +514,7 @@ impl fmt::Display for WeightedError {
450
514
WeightedError :: InvalidWeight => "A weight is invalid in distribution" ,
451
515
WeightedError :: AllWeightsZero => "All weights are zero in distribution" ,
452
516
WeightedError :: TooMany => "Too many weights (hit u32::MAX) in distribution" ,
517
+ WeightedError :: Overflow => "The sum of weights overflowed" ,
453
518
} )
454
519
}
455
520
}
0 commit comments