Skip to content

Commit 4f06af5

Browse files
committed
Switch var_axis and std_axis to use FromPrimitive
The documentation for the `Zero` and `One` traits says only that they are the additive and multiplicative identities; it doesn't say anything about converting an integer to a float by adding `One::one()` to `Zero::zero()` repeatedly. Additionally, it's nice to panic early instead of waiting until after the sum has been calculated.
1 parent 6060335 commit 4f06af5

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

src/numeric/impl_numeric.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use std::ops::{Add, Div, Mul};
10-
use libnum::{self, One, Zero, Float};
10+
use libnum::{self, One, Zero, Float, FromPrimitive};
1111
use itertools::free::enumerate;
1212

1313
use imp_prelude::*;
@@ -163,8 +163,11 @@ impl<A, S, D> ArrayBase<S, D>
163163
/// n i=1
164164
/// ```
165165
///
166-
/// **Panics** if `ddof` is less than zero or greater than the length of
167-
/// the axis or if `axis` is out of bounds.
166+
/// and `n` is the length of the axis.
167+
///
168+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
169+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
170+
/// numbers in the range `0..=n`.
168171
///
169172
/// # Example
170173
///
@@ -179,26 +182,27 @@ impl<A, S, D> ArrayBase<S, D>
179182
/// ```
180183
pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
181184
where
182-
A: Float,
185+
A: Float + FromPrimitive,
183186
D: RemoveAxis,
184187
{
185-
let mut count = A::zero();
188+
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
189+
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
190+
assert!(
191+
!(ddof < zero || ddof > n),
192+
"`ddof` must not be less than zero or greater than the length of \
193+
the axis",
194+
);
195+
let dof = n - ddof;
186196
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
187197
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
188-
for subview in self.axis_iter(axis) {
189-
count = count + A::one();
198+
for (i, subview) in self.axis_iter(axis).enumerate() {
199+
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
190200
azip!(mut mean, mut sum_sq, x (subview) in {
191201
let delta = x - *mean;
192202
*mean = *mean + delta / count;
193203
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
194204
});
195205
}
196-
assert!(
197-
!(ddof < A::zero() || ddof > count),
198-
"`ddof` must not be less than zero or greater than the length of \
199-
the axis",
200-
);
201-
let dof = count - ddof;
202206
sum_sq.mapv_into(|s| s / dof)
203207
}
204208

@@ -227,8 +231,11 @@ impl<A, S, D> ArrayBase<S, D>
227231
/// n i=1
228232
/// ```
229233
///
230-
/// **Panics** if `ddof` is less than zero or greater than the length of
231-
/// the axis or if `axis` is out of bounds.
234+
/// and `n` is the length of the axis.
235+
///
236+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
237+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
238+
/// numbers in the range `0..=n`.
232239
///
233240
/// # Example
234241
///
@@ -243,7 +250,7 @@ impl<A, S, D> ArrayBase<S, D>
243250
/// ```
244251
pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
245252
where
246-
A: Float,
253+
A: Float + FromPrimitive,
247254
D: RemoveAxis,
248255
{
249256
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())

0 commit comments

Comments
 (0)