Skip to content

Commit 6060335

Browse files
committed
Adjust constraints on ddof for var_axis and std_axis
Requiring `ddof` to not be < 0 allows us to ignore weird edge cases. Allowing `ddof` to be equal to the length of the axis allows computation of the population variance (`ddof` = 0) for a zero-length axis and computation of the sample variance (`ddof` = 1) for an axis of length one. (The result in these cases is an array of NaNs.) Note that this is a breaking change because of the new constraint that `ddof` must not be less than zero.
1 parent 7d35df4 commit 6060335

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

src/numeric/impl_numeric.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ impl<A, S, D> ArrayBase<S, D>
163163
/// n i=1
164164
/// ```
165165
///
166-
/// **Panics** if `ddof` is greater than or equal to the length of the
167-
/// axis or if `axis` is out of bounds.
166+
/// **Panics** if `ddof` is less than zero or greater than the length of
167+
/// the axis or if `axis` is out of bounds.
168168
///
169169
/// # Example
170170
///
@@ -193,13 +193,13 @@ impl<A, S, D> ArrayBase<S, D>
193193
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
194194
});
195195
}
196-
if ddof >= count {
197-
panic!("`ddof` needs to be strictly smaller than the length \
198-
of the axis you are computing the variance for!")
199-
} else {
200-
let dof = count - ddof;
201-
sum_sq.mapv_into(|s| s / dof)
202-
}
196+
assert!(
197+
!(ddof < A::zero() || ddof > count),
198+
"`ddof` must not be less than zero or greater than the length of \
199+
the axis",
200+
);
201+
let dof = count - ddof;
202+
sum_sq.mapv_into(|s| s / dof)
203203
}
204204

205205
/// Return standard deviation along `axis`.
@@ -227,8 +227,8 @@ impl<A, S, D> ArrayBase<S, D>
227227
/// n i=1
228228
/// ```
229229
///
230-
/// **Panics** if `ddof` is greater than or equal to the length of the
231-
/// axis or if `axis` is out of bounds.
230+
/// **Panics** if `ddof` is less than zero or greater than the length of
231+
/// the axis or if `axis` is out of bounds.
232232
///
233233
/// # Example
234234
///

tests/array.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -810,15 +810,32 @@ fn std_axis() {
810810

811811
#[test]
812812
#[should_panic]
813-
fn var_axis_bad_dof() {
813+
fn var_axis_negative_ddof() {
814+
let a = array![1., 2., 3.];
815+
a.var_axis(Axis(0), -1.);
816+
}
817+
818+
#[test]
819+
#[should_panic]
820+
fn var_axis_too_large_ddof() {
814821
let a = array![1., 2., 3.];
815822
a.var_axis(Axis(0), 4.);
816823
}
817824

825+
#[test]
826+
fn var_axis_nan_ddof() {
827+
let a = Array2::<f64>::zeros((2, 3));
828+
let v = a.var_axis(Axis(1), ::std::f64::NAN);
829+
assert_eq!(v.shape(), &[2]);
830+
v.mapv(|x| assert!(x.is_nan()));
831+
}
832+
818833
#[test]
819834
fn var_axis_empty_axis() {
820-
let a = array![[], []];
821-
a.var_axis(Axis(1), -1.);
835+
let a = Array2::<f64>::zeros((2, 0));
836+
let v = a.var_axis(Axis(1), 0.);
837+
assert_eq!(v.shape(), &[2]);
838+
v.mapv(|x| assert!(x.is_nan()));
822839
}
823840

824841
#[test]
@@ -830,8 +847,10 @@ fn std_axis_bad_dof() {
830847

831848
#[test]
832849
fn std_axis_empty_axis() {
833-
let a = array![[], []];
834-
a.std_axis(Axis(1), -1.);
850+
let a = Array2::<f64>::zeros((2, 0));
851+
let v = a.std_axis(Axis(1), 0.);
852+
assert_eq!(v.shape(), &[2]);
853+
v.mapv(|x| assert!(x.is_nan()));
835854
}
836855

837856
#[test]

0 commit comments

Comments
 (0)