7
7
// except according to those terms.
8
8
9
9
use std:: ops:: { Add , Div , Mul } ;
10
- use libnum:: { self , One , Zero , Float } ;
10
+ use libnum:: { self , One , Zero , Float , FromPrimitive } ;
11
11
use itertools:: free:: enumerate;
12
12
13
13
use imp_prelude:: * ;
@@ -174,8 +174,11 @@ impl<A, S, D> ArrayBase<S, D>
174
174
/// n i=1
175
175
/// ```
176
176
///
177
- /// **Panics** if `ddof` is greater than or equal to the length of the
178
- /// axis, if `axis` is out of bounds, or if the length of the axis is zero.
177
+ /// and `n` is the length of the axis.
178
+ ///
179
+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
180
+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
181
+ /// numbers in the range `0..=n`.
179
182
///
180
183
/// # Example
181
184
///
@@ -190,27 +193,28 @@ impl<A, S, D> ArrayBase<S, D>
190
193
/// ```
191
194
pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
192
195
where
193
- A : Float ,
196
+ A : Float + FromPrimitive ,
194
197
D : RemoveAxis ,
195
198
{
196
- let mut count = A :: zero ( ) ;
199
+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
200
+ let n = A :: from_usize ( self . len_of ( axis) ) . expect ( "Converting length to `A` must not fail." ) ;
201
+ assert ! (
202
+ !( ddof < zero || ddof > n) ,
203
+ "`ddof` must not be less than zero or greater than the length of \
204
+ the axis",
205
+ ) ;
206
+ let dof = n - ddof;
197
207
let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
198
208
let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
199
- for subview in self . axis_iter ( axis) {
200
- count = count + A :: one ( ) ;
209
+ for ( i , subview) in self . axis_iter ( axis) . enumerate ( ) {
210
+ let count = A :: from_usize ( i + 1 ) . expect ( "Converting index to `A` must not fail." ) ;
201
211
azip ! ( mut mean, mut sum_sq, x ( subview) in {
202
212
let delta = x - * mean;
203
213
* mean = * mean + delta / count;
204
214
* sum_sq = ( x - * mean) . mul_add( delta, * sum_sq) ;
205
215
} ) ;
206
216
}
207
- if ddof >= count {
208
- panic ! ( "`ddof` needs to be strictly smaller than the length \
209
- of the axis you are computing the variance for!")
210
- } else {
211
- let dof = count - ddof;
212
- sum_sq. mapv_into ( |s| s / dof)
213
- }
217
+ sum_sq. mapv_into ( |s| s / dof)
214
218
}
215
219
216
220
/// Return standard deviation along `axis`.
@@ -238,8 +242,11 @@ impl<A, S, D> ArrayBase<S, D>
238
242
/// n i=1
239
243
/// ```
240
244
///
241
- /// **Panics** if `ddof` is greater than or equal to the length of the
242
- /// axis, if `axis` is out of bounds, or if the length of the axis is zero.
245
+ /// and `n` is the length of the axis.
246
+ ///
247
+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
248
+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
249
+ /// numbers in the range `0..=n`.
243
250
///
244
251
/// # Example
245
252
///
@@ -254,7 +261,7 @@ impl<A, S, D> ArrayBase<S, D>
254
261
/// ```
255
262
pub fn std_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
256
263
where
257
- A : Float ,
264
+ A : Float + FromPrimitive ,
258
265
D : RemoveAxis ,
259
266
{
260
267
self . var_axis ( axis, ddof) . mapv_into ( |x| x. sqrt ( ) )
0 commit comments