Skip to content

Commit 6e37c46

Browse files
LukeMathWalkerjturner314
authored andcommitted
Make mean_axis return Option, like mean does
1 parent 3298cd5 commit 6e37c46

File tree

4 files changed

+24
-19
lines changed

4 files changed

+24
-19
lines changed

examples/column_standardize.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ fn main() {
2323
[ 2., 2., 2.]];
2424

2525
println!("{:8.4}", data);
26-
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)));
26+
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap());
2727

28-
data -= &data.mean_axis(Axis(0));
28+
data -= &data.mean_axis(Axis(0)).unwrap();
2929
println!("{:8.4}", data);
3030

3131
data /= &std(&data, Axis(0));

src/numeric/impl_numeric.rs

+16-9
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ impl<A, S, D> ArrayBase<S, D>
150150

151151
/// Return mean along `axis`.
152152
///
153-
/// **Panics** if `axis` is out of bounds, if the length of the axis is
154-
/// zero and division by zero panics for type `A`, or if `A::from_usize()`
153+
/// Return `None` if the length of the axis is zero.
154+
///
155+
/// **Panics** if `axis` is out of bounds or if `A::from_usize()`
155156
/// fails for the axis length.
156157
///
157158
/// ```
@@ -160,19 +161,25 @@ impl<A, S, D> ArrayBase<S, D>
160161
/// let a = arr2(&[[1., 2., 3.],
161162
/// [4., 5., 6.]]);
162163
/// assert!(
163-
/// a.mean_axis(Axis(0)) == aview1(&[2.5, 3.5, 4.5]) &&
164-
/// a.mean_axis(Axis(1)) == aview1(&[2., 5.]) &&
164+
/// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
165+
/// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
165166
///
166-
/// a.mean_axis(Axis(0)).mean_axis(Axis(0)) == aview0(&3.5)
167+
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
167168
/// );
168169
/// ```
169-
pub fn mean_axis(&self, axis: Axis) -> Array<A, D::Smaller>
170+
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
170171
where A: Clone + Zero + FromPrimitive + Add<Output=A> + Div<Output=A>,
171172
D: RemoveAxis,
172173
{
173-
let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail.");
174-
let sum = self.sum_axis(axis);
175-
sum / &aview0(&n)
174+
let axis_length = self.len_of(axis);
175+
if axis_length == 0 {
176+
None
177+
} else {
178+
let axis_length = A::from_usize(axis_length)
179+
.expect("Converting axis length to `A` must not fail.");
180+
let sum = self.sum_axis(axis);
181+
Some(sum / &aview0(&axis_length))
182+
}
176183
}
177184

178185
/// Return variance along `axis`.

tests/array.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -931,10 +931,10 @@ fn sum_mean()
931931
let a = arr2(&[[1., 2.], [3., 4.]]);
932932
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
933933
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
934-
assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.]));
935-
assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5]));
934+
assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.])));
935+
assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5])));
936936
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
937-
assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5]));
937+
assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5]));
938938
assert_eq!(a.sum(), 10.);
939939
}
940940

@@ -947,11 +947,9 @@ fn sum_mean_empty() {
947947
Array::zeros((2, 3)),
948948
);
949949
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
950-
assert_eq!(a.shape(), &[]);
951-
assert!(a[()].is_nan());
950+
assert_eq!(a, None);
952951
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
953-
assert_eq!(a.shape(), &[2, 3]);
954-
a.mapv(|x| assert!(x.is_nan()));
952+
assert_eq!(a, None);
955953
}
956954

957955
#[test]

tests/complex.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ fn complex_mat_mul()
2222
let r = a.dot(&e);
2323
println!("{}", a);
2424
assert_eq!(r, a);
25-
assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
25+
assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
2626
}

0 commit comments

Comments
 (0)