|
48 | 48 | fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A
|
49 | 49 | where
|
50 | 50 | A: AddAssign + Clone + FromPrimitive + Signed;
|
| 51 | + |
| 52 | + fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A |
| 53 | + where |
| 54 | + A: AddAssign + Clone + FromPrimitive + Signed + Float; |
51 | 55 | }
|
52 | 56 |
|
53 | 57 | impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
|
@@ -165,15 +169,24 @@ where
|
165 | 169 | {
|
166 | 170 | self.sq_l2_dist(other) / A::from_usize(self.len()).unwrap()
|
167 | 171 | }
|
| 172 | + |
| 173 | + fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> A |
| 174 | + where |
| 175 | + A: AddAssign + Clone + FromPrimitive + Signed + Float |
| 176 | + { |
| 177 | + self.mean_sq_dev(other).sqrt() |
| 178 | + } |
168 | 179 | }
|
169 | 180 |
|
170 | 181 | #[cfg(test)]
|
171 | 182 | mod tests {
|
172 | 183 | use super::*;
|
| 184 | + use approx::assert_abs_diff_eq; |
173 | 185 | use ndarray::*;
|
174 | 186 | use ndarray_rand::RandomExt;
|
175 | 187 | use num_traits::Pow;
|
176 | 188 | use rand::distributions::Uniform;
|
| 189 | + use std::f64; |
177 | 190 |
|
178 | 191 | #[test]
|
179 | 192 | fn test_count_eq() {
|
@@ -304,4 +317,14 @@ mod tests {
|
304 | 317 | assert_eq!(a.mean_sq_dev(&b), 10.);
|
305 | 318 | assert_eq!(b.mean_sq_dev(&a), 10.);
|
306 | 319 | }
|
| 320 | + |
| 321 | + #[test] |
| 322 | + fn test_root_mean_sq_dev() { |
| 323 | + let a = array![1., 1.]; |
| 324 | + let b = array![3., 5.]; |
| 325 | + |
| 326 | + assert_eq!(a.root_mean_sq_dev(&a), 0.); |
| 327 | + assert_abs_diff_eq!(a.root_mean_sq_dev(&b), 10.0.sqrt(), epsilon = f64::EPSILON); |
| 328 | + assert_abs_diff_eq!(b.root_mean_sq_dev(&a), 10.0.sqrt(), epsilon = f64::EPSILON); |
| 329 | + } |
307 | 330 | }
|
0 commit comments