Skip to content

Commit 9e4ac27

Browse files
committed
deviation: Implement count_eq
1 parent 6722934 commit 9e4ac27

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/deviation.rs

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
3+
/// Extension trait for `ArrayBase` providing functions
4+
/// to compute different deviation measures.
5+
pub trait DeviationExt<A, S, D>
6+
where
7+
S: Data<Elem = A>,
8+
D: Dimension,
9+
{
10+
fn count_eq(&self, other: &ArrayBase<S, D>) -> usize
11+
where
12+
A: PartialEq;
13+
}
14+
15+
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
16+
where
17+
S: Data<Elem = A>,
18+
D: Dimension,
19+
{
20+
fn count_eq(&self, other: &ArrayBase<S, D>) -> usize
21+
where
22+
A: PartialEq,
23+
{
24+
let mut c = 0;
25+
26+
Zip::from(self).and(other).apply(|a, b| {
27+
if a == b {
28+
c += 1;
29+
}
30+
});
31+
32+
c
33+
}
34+
}
35+
36+
#[cfg(test)]
37+
mod tests {
38+
use super::*;
39+
use ndarray::array;
40+
41+
#[test]
42+
fn test_count_eq() {
43+
let a = array![1., 2., 3., 4., 5., 6., 7.];
44+
let b = array![1., 3., 3., 4., 6., 7., 8.];
45+
let c = array![2., 4., 4., 5., 7., 8., 9.];
46+
let d = array![[1, 2], [3, 4], [5, 6]];
47+
let e = array![[1, 2], [4, 3], [5, 6]];
48+
49+
assert_eq!(a.count_eq(&b), 3);
50+
assert_eq!(b.count_eq(&c), 0);
51+
assert_eq!(d.count_eq(&e), 4);
52+
}
53+
}

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
//! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/
2929
3030
pub use crate::correlation::CorrelationExt;
31+
pub use crate::deviation::DeviationExt;
3132
pub use crate::entropy::EntropyExt;
3233
pub use crate::histogram::HistogramExt;
3334
pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
@@ -69,6 +70,7 @@ mod private {
6970
}
7071

7172
mod correlation;
73+
mod deviation;
7274
mod entropy;
7375
pub mod errors;
7476
pub mod histogram;

0 commit comments

Comments
 (0)