Skip to content

Commit cb6694b

Browse files
committed
deviation: Implement root_mean_sq_dev (no normalize param)
1 parent 8ba9539 commit cb6694b

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

src/deviation.rs

+23
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ where
4848
fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A
4949
where
5050
A: AddAssign + Clone + FromPrimitive + Signed;
51+
52+
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A
53+
where
54+
A: AddAssign + Clone + FromPrimitive + Signed + Float;
5155
}
5256

5357
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
@@ -165,15 +169,24 @@ where
165169
{
166170
self.sq_l2_dist(other) / A::from_usize(self.len()).unwrap()
167171
}
172+
173+
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A
174+
where
175+
A: AddAssign + Clone + FromPrimitive + Signed + Float
176+
{
177+
self.mean_sq_dev(other).sqrt()
178+
}
168179
}
169180

170181
#[cfg(test)]
171182
mod tests {
172183
use super::*;
184+
use approx::assert_abs_diff_eq;
173185
use ndarray::*;
174186
use ndarray_rand::RandomExt;
175187
use num_traits::Pow;
176188
use rand::distributions::Uniform;
189+
use std::f64;
177190

178191
#[test]
179192
fn test_count_eq() {
@@ -304,4 +317,14 @@ mod tests {
304317
assert_eq!(a.mean_sq_dev(&b), 10.);
305318
assert_eq!(b.mean_sq_dev(&a), 10.);
306319
}
320+
321+
#[test]
322+
fn test_root_mean_sq_dev() {
323+
let a = array![1., 1.];
324+
let b = array![3., 5.];
325+
326+
assert_eq!(a.root_mean_sq_dev(&a), 0.);
327+
assert_abs_diff_eq!(a.root_mean_sq_dev(&b), 10.0.sqrt(), epsilon = f64::EPSILON);
328+
assert_abs_diff_eq!(b.root_mean_sq_dev(&a), 10.0.sqrt(), epsilon = f64::EPSILON);
329+
}
307330
}

0 commit comments

Comments
 (0)