7
7
// except according to those terms.
8
8
9
9
use std:: ops:: { Add , Div , Mul } ;
10
- use libnum:: { self , One , Zero , Float } ;
10
+ use libnum:: { self , One , Zero , Float , FromPrimitive } ;
11
11
use itertools:: free:: enumerate;
12
12
13
13
use imp_prelude:: * ;
@@ -163,8 +163,11 @@ impl<A, S, D> ArrayBase<S, D>
163
163
/// n i=1
164
164
/// ```
165
165
///
166
- /// **Panics** if `ddof` is less than zero or greater than the length of
167
- /// the axis or if `axis` is out of bounds.
166
+ /// and `n` is the length of the axis.
167
+ ///
168
+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
169
+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
170
+ /// numbers in the range `0..=n`.
168
171
///
169
172
/// # Example
170
173
///
@@ -179,26 +182,27 @@ impl<A, S, D> ArrayBase<S, D>
179
182
/// ```
180
183
pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
181
184
where
182
- A : Float ,
185
+ A : Float + FromPrimitive ,
183
186
D : RemoveAxis ,
184
187
{
185
- let mut count = A :: zero ( ) ;
188
+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
189
+ let n = A :: from_usize ( self . len_of ( axis) ) . expect ( "Converting length to `A` must not fail." ) ;
190
+ assert ! (
191
+ !( ddof < zero || ddof > n) ,
192
+ "`ddof` must not be less than zero or greater than the length of \
193
+ the axis",
194
+ ) ;
195
+ let dof = n - ddof;
186
196
let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
187
197
let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
188
- for subview in self . axis_iter ( axis) {
189
- count = count + A :: one ( ) ;
198
+ for ( i , subview) in self . axis_iter ( axis) . enumerate ( ) {
199
+ let count = A :: from_usize ( i + 1 ) . expect ( "Converting index to `A` must not fail." ) ;
190
200
azip ! ( mut mean, mut sum_sq, x ( subview) in {
191
201
let delta = x - * mean;
192
202
* mean = * mean + delta / count;
193
203
* sum_sq = ( x - * mean) . mul_add( delta, * sum_sq) ;
194
204
} ) ;
195
205
}
196
- assert ! (
197
- !( ddof < A :: zero( ) || ddof > count) ,
198
- "`ddof` must not be less than zero or greater than the length of \
199
- the axis",
200
- ) ;
201
- let dof = count - ddof;
202
206
sum_sq. mapv_into ( |s| s / dof)
203
207
}
204
208
@@ -227,8 +231,11 @@ impl<A, S, D> ArrayBase<S, D>
227
231
/// n i=1
228
232
/// ```
229
233
///
230
- /// **Panics** if `ddof` is less than zero or greater than the length of
231
- /// the axis or if `axis` is out of bounds.
234
+ /// and `n` is the length of the axis.
235
+ ///
236
+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
237
+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
238
+ /// numbers in the range `0..=n`.
232
239
///
233
240
/// # Example
234
241
///
@@ -243,7 +250,7 @@ impl<A, S, D> ArrayBase<S, D>
243
250
/// ```
244
251
pub fn std_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
245
252
where
246
- A : Float ,
253
+ A : Float + FromPrimitive ,
247
254
D : RemoveAxis ,
248
255
{
249
256
self . var_axis ( axis, ddof) . mapv_into ( |x| x. sqrt ( ) )
0 commit comments