Skip to content

Commit 58627a3

Browse files
committed
Port deviation functions from StatsBase.jl
1 parent 6f898f6 commit 58627a3

File tree

5 files changed

+635
-2
lines changed

5 files changed

+635
-2
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ quickcheck = { version = "0.8.1", default-features = false }
3030
ndarray-rand = "0.9"
3131
approx = "0.3"
3232
quickcheck_macros = "0.8"
33+
num-bigint = "0.2.2"
3334

3435
[[bench]]
3536
name = "sort"

src/deviation.rs

+378
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
use num_traits::{Signed, ToPrimitive};
3+
use std::convert::Into;
4+
use std::ops::AddAssign;
5+
6+
use crate::errors::{MultiInputError, ShapeMismatch};
7+
8+
/// An extension trait for `ArrayBase` providing functions
9+
/// to compute different deviation measures.
10+
pub trait DeviationExt<A, S, D>
11+
where
12+
S: Data<Elem = A>,
13+
D: Dimension,
14+
{
15+
/// Counts the number of indices at which the elements of the arrays `self`
16+
/// and `other` are equal.
17+
///
18+
/// The following **errors** may be returned:
19+
///
20+
/// * `MultiInputError::EmptyInput` if `self` is empty
21+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
22+
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
23+
where
24+
A: PartialEq;
25+
26+
/// Counts the number of indices at which the elements of the arrays `self`
27+
/// and `other` are not equal.
28+
///
29+
/// The following **errors** may be returned:
30+
///
31+
/// * `MultiInputError::EmptyInput` if `self` is empty
32+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
33+
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
34+
where
35+
A: PartialEq;
36+
37+
/// Computes the [squared L2 distance] between `self` and `other`.
38+
///
39+
/// ```text
40+
/// n
41+
/// ∑ |aᵢ - bᵢ|²
42+
/// i=1
43+
/// ```
44+
///
45+
/// where `self` is `a` and `other` is `b`.
46+
///
47+
/// The following **errors** may be returned:
48+
///
49+
/// * `MultiInputError::EmptyInput` if `self` is empty
50+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
51+
///
52+
/// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
53+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
54+
where
55+
A: AddAssign + Clone + Signed;
56+
57+
/// Computes the [L2 distance] between `self` and `other`.
58+
///
59+
/// ```text
60+
/// n
61+
/// √ ∑ |aᵢ - bᵢ|²
62+
/// i=1
63+
/// ```
64+
///
65+
/// where `self` is `a` and `other` is `b`.
66+
///
67+
/// The following **errors** may be returned:
68+
///
69+
/// * `MultiInputError::EmptyInput` if `self` is empty
70+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
71+
///
72+
/// **Panics** if the type cast from `A` to `f64` fails.
73+
///
74+
/// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
75+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
76+
where
77+
A: AddAssign + Clone + Signed + ToPrimitive;
78+
79+
/// Computes the [L1 distance] between `self` and `other`.
80+
///
81+
/// ```text
82+
/// n
83+
/// ∑ |aᵢ - bᵢ|
84+
/// i=1
85+
/// ```
86+
///
87+
/// where `self` is `a` and `other` is `b`.
88+
///
89+
/// The following **errors** may be returned:
90+
///
91+
/// * `MultiInputError::EmptyInput` if `self` is empty
92+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
93+
///
94+
/// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
95+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
96+
where
97+
A: AddAssign + Clone + Signed;
98+
99+
/// Computes the [L∞ distance] between `self` and `other`.
100+
///
101+
/// ```text
102+
/// max(|aᵢ - bᵢ|)
103+
/// ᵢ
104+
/// ```
105+
///
106+
/// where `self` is `a` and `other` is `b`.
107+
///
108+
/// The following **errors** may be returned:
109+
///
110+
/// * `MultiInputError::EmptyInput` if `self` is empty
111+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
112+
///
113+
/// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
114+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
115+
where
116+
A: Clone + PartialOrd + Signed;
117+
118+
/// Computes the [mean absolute error] between `self` and `other`.
119+
///
120+
/// ```text
121+
/// n
122+
/// 1/n * ∑ |aᵢ - bᵢ|
123+
/// i=1
124+
/// ```
125+
///
126+
/// where `self` is `a` and `other` is `b`.
127+
///
128+
/// The following **errors** may be returned:
129+
///
130+
/// * `MultiInputError::EmptyInput` if `self` is empty
131+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
132+
///
133+
/// **Panics** if the type cast from `A` to `f64` fails.
134+
///
135+
/// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
136+
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
137+
where
138+
A: AddAssign + Clone + Signed + ToPrimitive;
139+
140+
/// Computes the [mean squared error] between `self` and `other`.
141+
///
142+
/// ```text
143+
/// n
144+
/// 1/n * ∑ |aᵢ - bᵢ|²
145+
/// i=1
146+
/// ```
147+
///
148+
/// where `self` is `a` and `other` is `b`.
149+
///
150+
/// The following **errors** may be returned:
151+
///
152+
/// * `MultiInputError::EmptyInput` if `self` is empty
153+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
154+
///
155+
/// **Panics** if the type cast from `A` to `f64` fails.
156+
///
157+
/// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
158+
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
159+
where
160+
A: AddAssign + Clone + Signed + ToPrimitive;
161+
162+
/// Computes the unnormalized [root-mean-square error] between `self` and `other`.
163+
///
164+
/// ```text
165+
/// √ mse(a, b)
166+
/// ```
167+
///
168+
/// where `self` is `a` and `other` is `b`.
169+
///
170+
/// where `mse` is the mean-squared-error.
171+
///
172+
/// The following **errors** may be returned:
173+
///
174+
/// * `MultiInputError::EmptyInput` if `self` is empty
175+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
176+
///
177+
/// **Panics** if the type cast from `A` to `f64` fails.
178+
///
179+
/// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
180+
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
181+
where
182+
A: AddAssign + Clone + Signed + ToPrimitive;
183+
184+
/// Computes the [peak signal-to-noise ratio] between `self` and `other`.
185+
///
186+
/// ```text
187+
/// 10 * log10(maxv^2 / mse(a, b))
188+
/// ```
189+
///
190+
/// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
191+
/// and `maxv` is the maximum possible value either array can take.
192+
///
193+
/// The following **errors** may be returned:
194+
///
195+
/// * `MultiInputError::EmptyInput` if `self` is empty
196+
/// * `ShapeMismatch` if `self` and `other` don't have the same shape
197+
///
198+
/// **Panics** if the type cast from `A` to `f64` fails.
199+
///
200+
/// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
201+
fn peak_signal_to_noise_ratio(
202+
&self,
203+
other: &ArrayBase<S, D>,
204+
maxv: A,
205+
) -> Result<f64, MultiInputError>
206+
where
207+
A: AddAssign + Clone + Signed + ToPrimitive;
208+
209+
private_decl! {}
210+
}
211+
212+
// Guard macro: makes the enclosing function return
// `Err(MultiInputError::EmptyInput)` when `$arr` holds no elements.
// `MultiInputError` is resolved in the scope where the macro is invoked.
macro_rules! return_err_if_empty {
    ($arr:expr) => {
        if $arr.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
    };
}
219+
// Guard macro: makes the enclosing function return a `ShapeMismatch` error
// (converted with `.into()` to the function's error type) when the two
// arrays do not have identical shapes.
//
// Each argument expression is evaluated exactly once: the shapes are bound
// to locals first, so an argument with side effects or a non-trivial cost is
// not re-run when the error value is built.
macro_rules! return_err_unless_same_shape {
    ($arr_a:expr, $arr_b:expr) => {{
        let (first, second) = ($arr_a.shape(), $arr_b.shape());
        if first != second {
            return Err(ShapeMismatch {
                first_shape: first.to_vec(),
                second_shape: second.to_vec(),
            }
            .into());
        }
    }};
}
230+
231+
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
232+
where
233+
S: Data<Elem = A>,
234+
D: Dimension,
235+
{
236+
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
237+
where
238+
A: PartialEq,
239+
{
240+
return_err_if_empty!(self);
241+
return_err_unless_same_shape!(self, other);
242+
243+
let mut count = 0;
244+
245+
Zip::from(self).and(other).apply(|a, b| {
246+
if a == b {
247+
count += 1;
248+
}
249+
});
250+
251+
Ok(count)
252+
}
253+
254+
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
255+
where
256+
A: PartialEq,
257+
{
258+
self.count_eq(other).map(|n_eq| self.len() - n_eq)
259+
}
260+
261+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
262+
where
263+
A: AddAssign + Clone + Signed,
264+
{
265+
return_err_if_empty!(self);
266+
return_err_unless_same_shape!(self, other);
267+
268+
let mut result = A::zero();
269+
270+
Zip::from(self).and(other).apply(|self_i, other_i| {
271+
let (a, b) = (self_i.clone(), other_i.clone());
272+
let abs_diff = (a - b).abs();
273+
result += abs_diff.clone() * abs_diff;
274+
});
275+
276+
Ok(result)
277+
}
278+
279+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
280+
where
281+
A: AddAssign + Clone + Signed + ToPrimitive,
282+
{
283+
let sq_l2_dist = self
284+
.sq_l2_dist(other)?
285+
.to_f64()
286+
.expect("failed cast from type A to f64");
287+
288+
Ok(sq_l2_dist.sqrt())
289+
}
290+
291+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
292+
where
293+
A: AddAssign + Clone + Signed,
294+
{
295+
return_err_if_empty!(self);
296+
return_err_unless_same_shape!(self, other);
297+
298+
let mut result = A::zero();
299+
300+
Zip::from(self).and(other).apply(|self_i, other_i| {
301+
let (a, b) = (self_i.clone(), other_i.clone());
302+
result += (a - b).abs();
303+
});
304+
305+
Ok(result)
306+
}
307+
308+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
309+
where
310+
A: Clone + PartialOrd + Signed,
311+
{
312+
return_err_if_empty!(self);
313+
return_err_unless_same_shape!(self, other);
314+
315+
let mut max = A::zero();
316+
317+
Zip::from(self).and(other).apply(|self_i, other_i| {
318+
let (a, b) = (self_i.clone(), other_i.clone());
319+
let diff = (a - b).abs();
320+
if diff > max {
321+
max = diff;
322+
}
323+
});
324+
325+
Ok(max)
326+
}
327+
328+
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
329+
where
330+
A: AddAssign + Clone + Signed + ToPrimitive,
331+
{
332+
let l1_dist = self
333+
.l1_dist(other)?
334+
.to_f64()
335+
.expect("failed cast from type A to f64");
336+
let n = self.len() as f64;
337+
338+
Ok(l1_dist / n)
339+
}
340+
341+
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
342+
where
343+
A: AddAssign + Clone + Signed + ToPrimitive,
344+
{
345+
let sq_l2_dist = self
346+
.sq_l2_dist(other)?
347+
.to_f64()
348+
.expect("failed cast from type A to f64");
349+
let n = self.len() as f64;
350+
351+
Ok(sq_l2_dist / n)
352+
}
353+
354+
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
355+
where
356+
A: AddAssign + Clone + Signed + ToPrimitive,
357+
{
358+
let msd = self.mean_sq_err(other)?;
359+
Ok(msd.sqrt())
360+
}
361+
362+
fn peak_signal_to_noise_ratio(
363+
&self,
364+
other: &ArrayBase<S, D>,
365+
maxv: A,
366+
) -> Result<f64, MultiInputError>
367+
where
368+
A: AddAssign + Clone + Signed + ToPrimitive,
369+
{
370+
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
371+
let msd = self.mean_sq_err(&other)?;
372+
let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
373+
374+
Ok(psnr)
375+
}
376+
377+
private_impl! {}
378+
}

0 commit comments

Comments
 (0)