Skip to content

Commit e1a6db8

Browse files
authored
Merge pull request #515 from jturner314/change-var-std-bounds
Adjust constraints on ddof for var_axis and std_axis
2 parents 899dca0 + 4f06af5 commit e1a6db8

File tree

2 files changed

+48
-24
lines changed

2 files changed

+48
-24
lines changed

src/numeric/impl_numeric.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use std::ops::{Add, Div, Mul};
10-
use libnum::{self, One, Zero, Float};
10+
use libnum::{self, One, Zero, Float, FromPrimitive};
1111
use itertools::free::enumerate;
1212

1313
use imp_prelude::*;
@@ -174,8 +174,11 @@ impl<A, S, D> ArrayBase<S, D>
174174
/// n i=1
175175
/// ```
176176
///
177-
/// **Panics** if `ddof` is greater than or equal to the length of the
178-
/// axis, if `axis` is out of bounds, or if the length of the axis is zero.
177+
/// and `n` is the length of the axis.
178+
///
179+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
180+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
181+
/// numbers in the range `0..=n`.
179182
///
180183
/// # Example
181184
///
@@ -190,27 +193,28 @@ impl<A, S, D> ArrayBase<S, D>
190193
/// ```
191194
pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
192195
where
193-
A: Float,
196+
A: Float + FromPrimitive,
194197
D: RemoveAxis,
195198
{
196-
let mut count = A::zero();
199+
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
200+
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
201+
assert!(
202+
!(ddof < zero || ddof > n),
203+
"`ddof` must not be less than zero or greater than the length of \
204+
the axis",
205+
);
206+
let dof = n - ddof;
197207
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
198208
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
199-
for subview in self.axis_iter(axis) {
200-
count = count + A::one();
209+
for (i, subview) in self.axis_iter(axis).enumerate() {
210+
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
201211
azip!(mut mean, mut sum_sq, x (subview) in {
202212
let delta = x - *mean;
203213
*mean = *mean + delta / count;
204214
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
205215
});
206216
}
207-
if ddof >= count {
208-
panic!("`ddof` needs to be strictly smaller than the length \
209-
of the axis you are computing the variance for!")
210-
} else {
211-
let dof = count - ddof;
212-
sum_sq.mapv_into(|s| s / dof)
213-
}
217+
sum_sq.mapv_into(|s| s / dof)
214218
}
215219

216220
/// Return standard deviation along `axis`.
@@ -238,8 +242,11 @@ impl<A, S, D> ArrayBase<S, D>
238242
/// n i=1
239243
/// ```
240244
///
241-
/// **Panics** if `ddof` is greater than or equal to the length of the
242-
/// axis, if `axis` is out of bounds, or if the length of the axis is zero.
245+
/// and `n` is the length of the axis.
246+
///
247+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
248+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
249+
/// numbers in the range `0..=n`.
243250
///
244251
/// # Example
245252
///
@@ -254,7 +261,7 @@ impl<A, S, D> ArrayBase<S, D>
254261
/// ```
255262
pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
256263
where
257-
A: Float,
264+
A: Float + FromPrimitive,
258265
D: RemoveAxis,
259266
{
260267
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())

tests/array.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -904,16 +904,32 @@ fn std_axis() {
904904

905905
#[test]
906906
#[should_panic]
907-
fn var_axis_bad_dof() {
907+
fn var_axis_negative_ddof() {
908908
let a = array![1., 2., 3.];
909-
a.var_axis(Axis(0), 4.);
909+
a.var_axis(Axis(0), -1.);
910910
}
911911

912912
#[test]
913913
#[should_panic]
914+
fn var_axis_too_large_ddof() {
915+
let a = array![1., 2., 3.];
916+
a.var_axis(Axis(0), 4.);
917+
}
918+
919+
#[test]
920+
fn var_axis_nan_ddof() {
921+
let a = Array2::<f64>::zeros((2, 3));
922+
let v = a.var_axis(Axis(1), ::std::f64::NAN);
923+
assert_eq!(v.shape(), &[2]);
924+
v.mapv(|x| assert!(x.is_nan()));
925+
}
926+
927+
#[test]
914928
fn var_axis_empty_axis() {
915-
let a = array![[], []];
916-
a.var_axis(Axis(1), 0.);
929+
let a = Array2::<f64>::zeros((2, 0));
930+
let v = a.var_axis(Axis(1), 0.);
931+
assert_eq!(v.shape(), &[2]);
932+
v.mapv(|x| assert!(x.is_nan()));
917933
}
918934

919935
#[test]
@@ -924,10 +940,11 @@ fn std_axis_bad_dof() {
924940
}
925941

926942
#[test]
927-
#[should_panic]
928943
fn std_axis_empty_axis() {
929-
let a = array![[], []];
930-
a.std_axis(Axis(1), 0.);
944+
let a = Array2::<f64>::zeros((2, 0));
945+
let v = a.std_axis(Axis(1), 0.);
946+
assert_eq!(v.shape(), &[2]);
947+
v.mapv(|x| assert!(x.is_nan()));
931948
}
932949

933950
#[test]

0 commit comments

Comments
 (0)