Skip to content

Commit fbac78b

Browse files
committed
Merge pull request #486 from LukeMathWalker/std_axis
2 parents db7a6f2 + fd1f541 commit fbac78b

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

src/numeric/impl_numeric.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,53 @@ impl<A, S, D> ArrayBase<S, D>
181181
}
182182
}
183183

184+
/// Return standard deviation along `axis`.
185+
///
186+
/// The standard deviation is computed from the variance using
187+
/// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
188+
///
189+
/// The parameter `ddof` specifies the "delta degrees of freedom". For
190+
/// example, to calculate the population standard deviation, use `ddof = 0`,
191+
/// or to calculate the sample standard deviation, use `ddof = 1`.
192+
///
193+
/// The standard deviation is defined as:
194+
///
195+
/// ```text
196+
/// 1 n
197+
/// stddev = sqrt ( ―――――――― ∑ (xᵢ - x̅)² )
198+
/// n - ddof i=1
199+
/// ```
200+
///
201+
/// where
202+
///
203+
/// ```text
204+
/// 1 n
205+
/// x̅ = ― ∑ xᵢ
206+
/// n i=1
207+
/// ```
208+
///
209+
/// **Panics** if `ddof` is greater than or equal to the length of the
210+
/// axis, if `axis` is out of bounds, or if the length of the axis is zero.
211+
///
212+
/// # Example
213+
///
214+
/// ```
215+
/// use ndarray::{aview1, arr2, Axis};
216+
///
217+
/// let a = arr2(&[[1., 2.],
218+
/// [3., 4.],
219+
/// [5., 6.]]);
220+
/// let stddev = a.std_axis(Axis(0), 1.);
221+
/// assert_eq!(stddev, aview1(&[2., 2.]));
222+
/// ```
223+
pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
224+
where
225+
A: Float,
226+
D: RemoveAxis,
227+
{
228+
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
229+
}
230+
184231
/// Return `true` if the arrays' elementwise differences are all within
185232
/// the given absolute tolerance, `false` otherwise.
186233
///

tests/array.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,53 @@ fn var_axis() {
745745
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
746746
}
747747

748+
#[test]
749+
fn std_axis() {
750+
let a = array![
751+
[
752+
[ 0.22935481, 0.08030619, 0.60827517, 0.73684379],
753+
[ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ],
754+
[ 0.44485313, 0.63316367, 0.11005111, 0.08656246]
755+
],
756+
[
757+
[ 0.28924665, 0.44082454, 0.59837736, 0.41014531],
758+
[ 0.08382316, 0.43259439, 0.1428889 , 0.44830176],
759+
[ 0.51529756, 0.70111616, 0.20799415, 0.91851457]
760+
],
761+
];
762+
assert!(a.std_axis(Axis(0), 1.5).all_close(
763+
&aview2(&[
764+
[ 0.05989184, 0.36051836, 0.00989781, 0.32669847],
765+
[ 0.81957535, 0.39599997, 0.49731472, 0.17084346],
766+
[ 0.07044443, 0.06795249, 0.09794304, 0.83195211],
767+
]),
768+
1e-4,
769+
));
770+
assert!(a.std_axis(Axis(1), 1.7).all_close(
771+
&aview2(&[
772+
[ 0.42698655, 0.48139215, 0.36874991, 0.41458724],
773+
[ 0.26769097, 0.18941435, 0.30555015, 0.35118674],
774+
]),
775+
1e-8,
776+
));
777+
assert!(a.std_axis(Axis(2), 2.3).all_close(
778+
&aview2(&[
779+
[ 0.41117907, 0.37130425, 0.35332388],
780+
[ 0.16905862, 0.25304841, 0.39978276],
781+
]),
782+
1e-8,
783+
));
784+
785+
let b = array![[100000., 1., 0.01]];
786+
assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
787+
assert!(
788+
b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6),
789+
);
790+
791+
let c = array![[], []];
792+
assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[]));
793+
}
794+
748795
#[test]
749796
#[should_panic]
750797
fn var_axis_bad_dof() {
@@ -759,6 +806,20 @@ fn var_axis_empty_axis() {
759806
a.var_axis(Axis(1), 0.);
760807
}
761808

809+
#[test]
810+
#[should_panic]
811+
fn std_axis_bad_dof() {
812+
let a = array![1., 2., 3.];
813+
a.std_axis(Axis(0), 4.);
814+
}
815+
816+
#[test]
817+
#[should_panic]
818+
fn std_axis_empty_axis() {
819+
let a = array![[], []];
820+
a.std_axis(Axis(1), 0.);
821+
}
822+
762823
#[test]
763824
fn iter_size_hint()
764825
{

0 commit comments

Comments
 (0)