Skip to content

Commit f3930b5

Browse files
committed
test/bench cleanup
1 parent a3b5fa0 commit f3930b5

File tree

7 files changed

+109
-212
lines changed

7 files changed

+109
-212
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ thiserror = "1.0.64"
2424
[dev-dependencies]
2525
criterion = "0.5.1"
2626
ndarray-rand = "0.15.0"
27+
approx = "0.5.1"
2728

2829
[[bench]]
2930
name = "benchmark"

benches/benchmark.rs

+19-19
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Benchmarks for 0/1/2/3/N-dimensional linear interpolation
22
//! Run these with `cargo bench`
33
4-
use criterion::{criterion_group, criterion_main, Criterion};
4+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
55

66
use ndarray::prelude::*;
77
use ninterp::prelude::*;
@@ -29,7 +29,7 @@ fn benchmark_0D_multi() {
2929
Extrapolate::Error,
3030
)
3131
.unwrap();
32-
interp_0d_multi.interpolate(&[]).unwrap();
32+
interp_0d_multi.interpolate(black_box(&[])).unwrap();
3333
}
3434

3535
#[allow(non_snake_case)]
@@ -44,7 +44,7 @@ fn benchmark_1D() {
4444
// Sample 1,000 points
4545
let points: Vec<f64> = (0..1_000).map(|_| rng.gen::<f64>() * 99.).collect();
4646
for point in points {
47-
interp_1d.interpolate(&[point]).unwrap();
47+
interp_1d.interpolate(black_box(&[point])).unwrap();
4848
}
4949
}
5050

@@ -61,7 +61,7 @@ fn benchmark_1D_multi() {
6161
// Sample 1,000 points
6262
let points: Vec<f64> = (0..1_000).map(|_| rng.gen::<f64>() * 99.).collect();
6363
for point in points {
64-
interp_1d_multi.interpolate(&[point]).unwrap();
64+
interp_1d_multi.interpolate(black_box(&[point])).unwrap();
6565
}
6666
}
6767

@@ -73,9 +73,9 @@ fn benchmark_2D() {
7373
let values_data = Array2::random_using((100, 100), Uniform::new(0., 1.), &mut rng);
7474
// Create a 2-D interpolator with 100x100 data (10,000 points)
7575
let interp_2d = Interp2D::new(
76-
grid_data.clone(),
77-
grid_data.clone(),
78-
values_data,
76+
grid_data.view(),
77+
grid_data.view(),
78+
values_data.view(),
7979
Linear,
8080
Extrapolate::Error,
8181
)
@@ -85,7 +85,7 @@ fn benchmark_2D() {
8585
.map(|_| vec![rng.gen::<f64>() * 99., rng.gen::<f64>() * 99.])
8686
.collect();
8787
for point in points {
88-
interp_2d.interpolate(&point).unwrap();
88+
interp_2d.interpolate(black_box(&point)).unwrap();
8989
}
9090
}
9191

@@ -98,8 +98,8 @@ fn benchmark_2D_multi() {
9898
let values_data = Array2::random_using((100, 100), Uniform::new(0., 1.), &mut rng).into_dyn();
9999
// Create an N-D interpolator with 100x100 data (10,000 points)
100100
let interp_2d_multi = InterpND::new(
101-
vec![grid_data.clone(), grid_data.clone()],
102-
values_data,
101+
vec![grid_data.view(), grid_data.view()],
102+
values_data.view(),
103103
Linear,
104104
Extrapolate::Error,
105105
)
@@ -109,7 +109,7 @@ fn benchmark_2D_multi() {
109109
.map(|_| vec![rng.gen::<f64>() * 99., rng.gen::<f64>() * 99.])
110110
.collect();
111111
for point in points {
112-
interp_2d_multi.interpolate(&point).unwrap();
112+
interp_2d_multi.interpolate(black_box(&point)).unwrap();
113113
}
114114
}
115115

@@ -122,10 +122,10 @@ fn benchmark_3D() {
122122
let values_data = Array3::random_using((100, 100, 100), Uniform::new(0., 1.), &mut rng);
123123
// Create a 3-D interpolator with 100x100x100 data (1,000,000 points)
124124
let interp_3d = Interp3D::new(
125-
grid_data.clone(),
126-
grid_data.clone(),
127-
grid_data.clone(),
128-
values_data,
125+
grid_data.view(),
126+
grid_data.view(),
127+
grid_data.view(),
128+
values_data.view(),
129129
Linear,
130130
Extrapolate::Error,
131131
)
@@ -141,7 +141,7 @@ fn benchmark_3D() {
141141
})
142142
.collect();
143143
for point in points {
144-
interp_3d.interpolate(&point).unwrap();
144+
interp_3d.interpolate(black_box(&point)).unwrap();
145145
}
146146
}
147147

@@ -155,8 +155,8 @@ fn benchmark_3D_multi() {
155155
Array3::random_using((100, 100, 100), Uniform::new(0., 1.), &mut rng).into_dyn();
156156
// Create an N-D interpolator with 100x100x100 data (1,000,000 points)
157157
let interp_3d_multi = InterpND::new(
158-
vec![grid_data.clone(), grid_data.clone(), grid_data.clone()],
159-
values_data,
158+
vec![grid_data.view(), grid_data.view(), grid_data.view()],
159+
values_data.view(),
160160
Linear,
161161
Extrapolate::Error,
162162
)
@@ -172,7 +172,7 @@ fn benchmark_3D_multi() {
172172
})
173173
.collect();
174174
for point in points {
175-
interp_3d_multi.interpolate(&point).unwrap();
175+
interp_3d_multi.interpolate(black_box(&point)).unwrap();
176176
}
177177
}
178178

src/lib.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,19 @@ pub(crate) use ndarray::DataOwned;
148148
#[cfg(feature = "serde")]
149149
pub(crate) use serde::{de::DeserializeOwned, Deserialize, Serialize};
150150

151+
#[cfg(test)]
152+
/// Alias for [`approx::assert_abs_diff_eq`] with `epsilon = 1e-6`
153+
macro_rules! assert_approx_eq {
154+
($a:expr, $b:expr $(,)?) => {
155+
approx::assert_abs_diff_eq!($a, $b, epsilon = 1e-6)
156+
};
157+
($a:expr, $b:expr, $eps:expr $(,)?) => {
158+
approx::assert_abs_diff_eq!($a, $b, epsilon = $eps)
159+
};
160+
}
161+
#[cfg(test)]
162+
pub(crate) use assert_approx_eq;
163+
151164
/// An interpolator of data type `T`
152165
///
153166
/// This trait is dyn-compatible, meaning you can use:
@@ -214,7 +227,6 @@ where
214227
/// is outside the bounds of the interpolation grid.
215228
#[derive(Clone, Copy, Debug, PartialEq, Default)]
216229
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
217-
// #[cfg_attr(feature = "serde", serde(bound = "T: Serialize + DeserializeOwned"))]
218230
pub enum Extrapolate<T> {
219231
/// Evaluate beyond the limits of the interpolation grid.
220232
Enable,

src/n/mod.rs

+25-52
Original file line numberDiff line numberDiff line change
@@ -244,58 +244,34 @@ mod tests {
244244

245245
#[test]
246246
fn test_linear() {
247-
let grid = vec![
248-
array![0.05, 0.10, 0.15],
249-
array![0.10, 0.20, 0.30],
250-
array![0.20, 0.40, 0.60],
251-
];
247+
let x = array![0.05, 0.10, 0.15];
248+
let y = array![0.10, 0.20, 0.30];
249+
let z = array![0.20, 0.40, 0.60];
250+
let grid = vec![x.view(), y.view(), z.view()];
252251
let values = array![
253252
[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
254253
[[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]],
255254
[[18., 19., 20.], [21., 22., 23.], [24., 25., 26.]],
256255
]
257256
.into_dyn();
258-
let interp =
259-
InterpND::new(grid.clone(), values.clone(), Linear, Extrapolate::Error).unwrap();
257+
let interp = InterpND::new(grid, values.view(), Linear, Extrapolate::Error).unwrap();
260258
// Check that interpolating at grid points just retrieves the value
261-
for i in 0..grid[0].len() {
262-
for j in 0..grid[1].len() {
263-
for k in 0..grid[2].len() {
259+
for i in 0..x.len() {
260+
for j in 0..y.len() {
261+
for k in 0..z.len() {
264262
assert_eq!(
265-
&interp
266-
.interpolate(&[grid[0][i], grid[1][j], grid[2][k]])
267-
.unwrap(),
263+
&interp.interpolate(&[x[i], y[j], z[k]]).unwrap(),
268264
values.slice(s![i, j, k]).first().unwrap()
269265
);
270266
}
271267
}
272268
}
273-
assert_eq!(
274-
interp.interpolate(&[grid[0][0], grid[1][0], 0.3]).unwrap(),
275-
0.4999999999999999 // 0.5
276-
);
277-
assert_eq!(
278-
interp.interpolate(&[grid[0][0], 0.15, grid[2][0]]).unwrap(),
279-
1.4999999999999996 // 1.5
280-
);
281-
assert_eq!(
282-
interp.interpolate(&[grid[0][0], 0.15, 0.3]).unwrap(),
283-
1.9999999999999996 // 2.0
284-
);
285-
assert_eq!(
286-
interp
287-
.interpolate(&[0.075, grid[1][0], grid[2][0]])
288-
.unwrap(),
289-
4.499999999999999 // 4.5
290-
);
291-
assert_eq!(
292-
interp.interpolate(&[0.075, grid[1][0], 0.3]).unwrap(),
293-
4.999999999999999 // 5.0
294-
);
295-
assert_eq!(
296-
interp.interpolate(&[0.075, 0.15, grid[2][0]]).unwrap(),
297-
5.999999999999998 // 6.0
298-
);
269+
assert_approx_eq!(interp.interpolate(&[x[0], y[0], 0.3]).unwrap(), 0.5);
270+
assert_approx_eq!(interp.interpolate(&[x[0], 0.15, z[0]]).unwrap(), 1.5);
271+
assert_approx_eq!(interp.interpolate(&[x[0], 0.15, 0.3]).unwrap(), 2.0);
272+
assert_approx_eq!(interp.interpolate(&[0.075, y[0], z[0]]).unwrap(), 4.5);
273+
assert_approx_eq!(interp.interpolate(&[0.075, y[0], 0.3]).unwrap(), 5.);
274+
assert_approx_eq!(interp.interpolate(&[0.075, 0.15, z[0]]).unwrap(), 6.);
299275
}
300276

301277
#[test]
@@ -307,10 +283,7 @@ mod tests {
307283
Extrapolate::Error,
308284
)
309285
.unwrap();
310-
assert_eq!(
311-
interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(),
312-
3.1999999999999997
313-
) // 3.2
286+
assert_approx_eq!(interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(), 3.2)
314287
}
315288

316289
#[test]
@@ -475,18 +448,18 @@ mod tests {
475448

476449
#[test]
477450
fn test_nearest() {
478-
let grid = vec![array![0., 1.], array![0., 1.], array![0., 1.]];
451+
let x = array![0., 1.];
452+
let y = array![0., 1.];
453+
let z = array![0., 1.];
454+
let grid = vec![x.view(), y.view(), z.view()];
479455
let values = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn();
480-
let interp =
481-
InterpND::new(grid.clone(), values.clone(), Nearest, Extrapolate::Error).unwrap();
456+
let interp = InterpND::new(grid, values.view(), Nearest, Extrapolate::Error).unwrap();
482457
// Check that interpolating at grid points just retrieves the value
483-
for i in 0..grid[0].len() {
484-
for j in 0..grid[1].len() {
485-
for k in 0..grid[2].len() {
458+
for i in 0..x.len() {
459+
for j in 0..y.len() {
460+
for k in 0..z.len() {
486461
assert_eq!(
487-
&interp
488-
.interpolate(&[grid[0][i], grid[1][j], grid[2][k]])
489-
.unwrap(),
462+
&interp.interpolate(&[x[i], y[j], z[k]]).unwrap(),
490463
values.slice(s![i, j, k]).first().unwrap()
491464
);
492465
}

src/one/mod.rs

+5-10
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ mod tests {
216216
fn test_linear() {
217217
let x = array![0., 1., 2., 3., 4.];
218218
let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0];
219-
let interp = Interp1D::new(x.clone(), f_x.clone(), Linear, Extrapolate::Error).unwrap();
219+
let interp = Interp1D::new(x.view(), f_x.view(), Linear, Extrapolate::Error).unwrap();
220220
// Check that interpolating at grid points just retrieves the value
221221
for (i, x_i) in x.iter().enumerate() {
222222
assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]);
@@ -230,8 +230,7 @@ mod tests {
230230
fn test_left_nearest() {
231231
let x = array![0., 1., 2., 3., 4.];
232232
let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0];
233-
let interp =
234-
Interp1D::new(x.clone(), f_x.clone(), LeftNearest, Extrapolate::Error).unwrap();
233+
let interp = Interp1D::new(x.view(), f_x.view(), LeftNearest, Extrapolate::Error).unwrap();
235234
// Check that interpolating at grid points just retrieves the value
236235
for (i, x_i) in x.iter().enumerate() {
237236
assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]);
@@ -245,8 +244,7 @@ mod tests {
245244
fn test_right_nearest() {
246245
let x = array![0., 1., 2., 3., 4.];
247246
let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0];
248-
let interp =
249-
Interp1D::new(x.clone(), f_x.clone(), RightNearest, Extrapolate::Error).unwrap();
247+
let interp = Interp1D::new(x.view(), f_x.view(), RightNearest, Extrapolate::Error).unwrap();
250248
// Check that interpolating at grid points just retrieves the value
251249
for (i, x_i) in x.iter().enumerate() {
252250
assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]);
@@ -260,7 +258,7 @@ mod tests {
260258
fn test_nearest() {
261259
let x = array![0., 1., 2., 3., 4.];
262260
let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0];
263-
let interp = Interp1D::new(x.clone(), f_x.clone(), Nearest, Extrapolate::Error).unwrap();
261+
let interp = Interp1D::new(x.view(), f_x.view(), Nearest, Extrapolate::Error).unwrap();
264262
// Check that interpolating at grid points just retrieves the value
265263
for (i, x_i) in x.iter().enumerate() {
266264
assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]);
@@ -344,10 +342,7 @@ mod tests {
344342
)
345343
.unwrap();
346344
assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.0);
347-
assert_eq!(
348-
interp.interpolate(&[-0.75]).unwrap(),
349-
0.04999999999999999 // 0.05
350-
);
345+
assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05);
351346
assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2);
352347
}
353348
}

0 commit comments

Comments
 (0)