Skip to content

Commit 70f93de

Browse files
authored
Merge pull request #4 from NREL/clone
Make types Clone-able
2 parents cd7ca8f + 49c77f7 commit 70f93de

16 files changed

+198
-137
lines changed

Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ keywords = [
1515
categories = ["mathematics"]
1616

1717
[dependencies]
18+
dyn-clone = "1.0.19"
1819
itertools = "0.13.0"
19-
ndarray = "0.16.1"
20+
ndarray = ">=0.15, <0.17"
2021
num-traits = "0.2.19"
21-
serde = { version = "1.0.210", optional = true, features = ["derive"] }
22+
serde = { version = "1", optional = true, features = ["derive"] }
2223
thiserror = "1.0.64"
2324

2425
[dev-dependencies]

examples/custom_strategy.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
use ndarray::prelude::*;
2-
1+
use ninterp::data::InterpData2D;
32
use ninterp::prelude::*;
43
use ninterp::strategy::*;
54

5+
// Note: ninterp also re-exposes the internally used `ndarray` crate
6+
// `use ninterp::ndarray;`
7+
use ndarray::prelude::*;
8+
use ndarray::{Data, RawDataClone};
9+
610
// Debug must be derived for custom strategies
7-
#[derive(Debug)]
11+
#[derive(Debug, Clone)]
812
struct CustomStrategy;
913

1014
// Implement strategy for 2-D f32 interpolation
@@ -14,7 +18,7 @@ where
1418
// e.g. `Array2<f32>`, `ArrayView2<f32>`, `CowArray<<'a, f32>, Ix2>`, etc.
1519
// For a more generic bound, consider introducing a bound for D::Elem
1620
// e.g. D::Elem: num_traits::Num + PartialOrd
17-
D: ndarray::Data<Elem = f32>,
21+
D: Data<Elem = f32> + RawDataClone + Clone,
1822
{
1923
fn interpolate(
2024
&self,

examples/dynamic_interpolator.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ fn main() {
2525
)
2626
.unwrap(),
2727
);
28-
assert_eq!(boxed.interpolate(&[1.75]).unwrap(), 8.)
28+
assert_eq!(boxed.interpolate(&[1.75]).unwrap(), 8.);
2929
}

examples/uom.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ fn main() {
1212
let f_x = array![Power::new::<kilowatt>(0.25), Power::new::<kilowatt>(0.75)];
1313
// `uom::si::Quantity` is repr(transparent), meaning it has the same memory layout as its contained type.
1414
// This means we can get the contained type via transmuting.
15-
let interp: Interp1D<ndarray::OwnedRepr<f64>, _> = unsafe {
15+
let interp: ninterp::one::Interp1DOwned<f64, _> = unsafe {
1616
Interp1D::new(
1717
std::mem::transmute(x),
1818
std::mem::transmute(f_x),

src/data.rs

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use super::*;
2+
3+
pub use crate::n::{InterpDataND, InterpDataNDOwned, InterpDataNDViewed};
4+
pub use crate::one::{InterpData1D, InterpData1DOwned, InterpData1DViewed};
5+
pub use crate::three::{InterpData3D, InterpData3DOwned, InterpData3DViewed};
6+
pub use crate::two::{InterpData2D, InterpData2DOwned, InterpData2DViewed};
7+
8+
#[derive(Debug, Clone)]
9+
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
10+
#[cfg_attr(
11+
feature = "serde",
12+
serde(bound = "
13+
D: DataOwned,
14+
D::Elem: Serialize + DeserializeOwned,
15+
Dim<[usize; N]>: Serialize + DeserializeOwned,
16+
[ArrayBase<D, Ix1>; N]: Serialize + DeserializeOwned,
17+
")
18+
)]
19+
pub struct InterpData<D, const N: usize>
20+
where
21+
Dim<[Ix; N]>: Dimension,
22+
D: Data + RawDataClone + Clone,
23+
D::Elem: Num + PartialOrd + Copy + Debug,
24+
{
25+
pub grid: [ArrayBase<D, Ix1>; N],
26+
pub values: ArrayBase<D, Dim<[Ix; N]>>,
27+
}
28+
pub type InterpDataViewed<T, const N: usize> = InterpData<ndarray::ViewRepr<T>, N>;
29+
pub type InterpDataOwned<T, const N: usize> = InterpData<ndarray::ViewRepr<T>, N>;
30+
31+
impl<D, const N: usize> InterpData<D, N>
32+
where
33+
Dim<[Ix; N]>: Dimension,
34+
D: Data + RawDataClone + Clone,
35+
D::Elem: Num + PartialOrd + Copy + Debug,
36+
{
37+
pub fn validate(&self) -> Result<(), ValidateError> {
38+
for i in 0..N {
39+
let i_grid_len = self.grid[i].len();
40+
// Check that each grid dimension has elements
41+
// Indexing `grid` directly is okay because empty dimensions are caught at compilation
42+
if i_grid_len == 0 {
43+
return Err(ValidateError::EmptyGrid(i));
44+
}
45+
// Check that grid points are monotonically increasing
46+
if !self.grid[i].windows(2).into_iter().all(|w| w[0] <= w[1]) {
47+
return Err(ValidateError::Monotonicity(i));
48+
}
49+
// Check that grid and values are compatible shapes
50+
if i_grid_len != self.values.shape()[i] {
51+
return Err(ValidateError::IncompatibleShapes(i));
52+
}
53+
}
54+
Ok(())
55+
}
56+
}

src/lib.rs

+15-56
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ pub mod prelude {
123123
pub use crate::Interpolator;
124124
}
125125

126+
pub mod data;
126127
pub mod error;
127128
pub mod strategy;
128129

@@ -133,23 +134,27 @@ pub mod two;
133134
pub mod zero;
134135

135136
pub mod interpolator {
136-
pub use crate::n::InterpND;
137-
pub use crate::one::Interp1D;
138-
pub use crate::three::Interp3D;
139-
pub use crate::two::Interp2D;
137+
pub use crate::n::{InterpND, InterpNDOwned, InterpNDViewed};
138+
pub use crate::one::{Interp1D, Interp1DOwned, Interp1DViewed};
139+
pub use crate::three::{Interp3D, Interp3DOwned, Interp3DViewed};
140+
pub use crate::two::{Interp2D, Interp2DOwned, Interp2DViewed};
140141
pub use crate::zero::Interp0D;
141142
}
142143

144+
pub(crate) use data::*;
143145
pub(crate) use error::*;
144146
pub(crate) use strategy::*;
145147

146148
pub(crate) use std::fmt::Debug;
147149

150+
pub use ndarray;
148151
pub(crate) use ndarray::prelude::*;
149-
pub(crate) use ndarray::{Data, Ix};
152+
pub(crate) use ndarray::{Data, Ix, RawDataClone};
150153

151154
pub(crate) use num_traits::{clamp, Num, One};
152155

156+
pub(crate) use dyn_clone::*;
157+
153158
#[cfg(feature = "serde")]
154159
pub(crate) use ndarray::DataOwned;
155160
#[cfg(feature = "serde")]
@@ -173,7 +178,7 @@ pub(crate) use assert_approx_eq;
173178
/// This trait is dyn-compatible, meaning you can use:
174179
/// `Box<dyn Interpolator<_>>`
175180
/// and swap the contained interpolator at runtime.
176-
pub trait Interpolator<T> {
181+
pub trait Interpolator<T>: DynClone {
177182
/// Interpolator dimensionality.
178183
fn ndim(&self) -> usize;
179184
/// Validate interpolator data.
@@ -182,6 +187,8 @@ pub trait Interpolator<T> {
182187
fn interpolate(&self, point: &[T]) -> Result<T, InterpolateError>;
183188
}
184189

190+
clone_trait_object!(<T> Interpolator<T>);
191+
185192
impl<T> Interpolator<T> for Box<dyn Interpolator<T>> {
186193
fn ndim(&self) -> usize {
187194
(**self).ndim()
@@ -194,54 +201,6 @@ impl<T> Interpolator<T> for Box<dyn Interpolator<T>> {
194201
}
195202
}
196203

197-
#[derive(Debug)]
198-
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
199-
#[cfg_attr(
200-
feature = "serde",
201-
serde(bound = "
202-
D: DataOwned,
203-
D::Elem: Serialize + DeserializeOwned,
204-
Dim<[usize; N]>: Serialize + DeserializeOwned,
205-
[ArrayBase<D, Ix1>; N]: Serialize + DeserializeOwned,
206-
")
207-
)]
208-
pub struct InterpData<D, const N: usize>
209-
where
210-
Dim<[Ix; N]>: Dimension,
211-
D: Data,
212-
D::Elem: Num + PartialOrd + Copy + Debug,
213-
{
214-
pub grid: [ArrayBase<D, Ix1>; N],
215-
pub values: ArrayBase<D, Dim<[Ix; N]>>,
216-
}
217-
218-
impl<D, const N: usize> InterpData<D, N>
219-
where
220-
Dim<[Ix; N]>: Dimension,
221-
D: Data,
222-
D::Elem: Num + PartialOrd + Copy + Debug,
223-
{
224-
pub fn validate(&self) -> Result<(), ValidateError> {
225-
for i in 0..N {
226-
let i_grid_len = self.grid[i].len();
227-
// Check that each grid dimension has elements
228-
// Indexing `grid` directly is okay because empty dimensions are caught at compilation
229-
if i_grid_len == 0 {
230-
return Err(ValidateError::EmptyGrid(i));
231-
}
232-
// Check that grid points are monotonically increasing
233-
if !self.grid[i].windows(2).into_iter().all(|w| w[0] <= w[1]) {
234-
return Err(ValidateError::Monotonicity(i));
235-
}
236-
// Check that grid and values are compatible shapes
237-
if i_grid_len != self.values.shape()[i] {
238-
return Err(ValidateError::IncompatibleShapes(i));
239-
}
240-
}
241-
Ok(())
242-
}
243-
}
244-
245204
/// Extrapolation strategy
246205
///
247206
/// Controls what happens if supplied interpolant point
@@ -264,9 +223,9 @@ macro_rules! extrapolate_impl {
264223
($InterpType:ident, $Strategy:ident) => {
265224
impl<D, S> $InterpType<D, S>
266225
where
267-
D: Data,
226+
D: Data + RawDataClone + Clone,
268227
D::Elem: Num + PartialOrd + Copy + Debug,
269-
S: $Strategy<D>,
228+
S: $Strategy<D> + Clone,
270229
{
271230
/// Set [`Extrapolate`] variant, checking validity.
272231
pub fn set_extrapolate(

src/n/mod.rs

+20-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use ndarray::prelude::*;
66

77
mod strategies;
88
/// Interpolator data where N is determined at runtime
9-
#[derive(Debug)]
9+
#[derive(Debug, Clone)]
1010
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
1111
#[cfg_attr(
1212
feature = "serde",
@@ -17,15 +17,20 @@ mod strategies;
1717
)]
1818
pub struct InterpDataND<D>
1919
where
20-
D: Data,
20+
D: Data + RawDataClone + Clone,
2121
D::Elem: Num + PartialOrd + Copy + Debug,
2222
{
2323
pub grid: Vec<ArrayBase<D, Ix1>>,
2424
pub values: ArrayBase<D, IxDyn>,
2525
}
26+
/// [`InterpDataND`] that views data.
27+
pub type InterpDataNDViewed<T> = InterpDataND<ndarray::ViewRepr<T>>;
28+
/// [`InterpDataND`] that owns data.
29+
pub type InterpDataNDOwned<T> = InterpDataND<ndarray::OwnedRepr<T>>;
30+
2631
impl<D> InterpDataND<D>
2732
where
28-
D: Data,
33+
D: Data + RawDataClone + Clone,
2934
D::Elem: Num + PartialOrd + Copy + Debug,
3035
{
3136
pub fn ndim(&self) -> usize {
@@ -76,7 +81,7 @@ where
7681
}
7782

7883
/// N-D interpolator
79-
#[derive(Debug)]
84+
#[derive(Debug, Clone)]
8085
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
8186
#[cfg_attr(
8287
feature = "serde",
@@ -88,9 +93,9 @@ where
8893
)]
8994
pub struct InterpND<D, S>
9095
where
91-
D: Data,
96+
D: Data + RawDataClone + Clone,
9297
D::Elem: Num + PartialOrd + Copy + Debug,
93-
S: StrategyND<D>,
98+
S: StrategyND<D> + Clone,
9499
{
95100
pub data: InterpDataND<D>,
96101
pub strategy: S,
@@ -101,14 +106,18 @@ where
101106
)]
102107
pub extrapolate: Extrapolate<D::Elem>,
103108
}
109+
/// [`InterpND`] that views data.
110+
pub type InterpNDViewed<T, S> = InterpND<ndarray::ViewRepr<T>, S>;
111+
/// [`InterpND`] that owns data.
112+
pub type InterpNDOwned<T, S> = InterpND<ndarray::OwnedRepr<T>, S>;
104113

105114
extrapolate_impl!(InterpND, StrategyND);
106115

107116
impl<D, S> InterpND<D, S>
108117
where
109-
D: Data,
118+
D: Data + RawDataClone + Clone,
110119
D::Elem: Num + PartialOrd + Copy + Debug,
111-
S: StrategyND<D>,
120+
S: StrategyND<D> + Clone,
112121
{
113122
/// Instantiate N-dimensional (any dimensionality) interpolator.
114123
///
@@ -170,9 +179,9 @@ where
170179

171180
impl<D, S> Interpolator<D::Elem> for InterpND<D, S>
172181
where
173-
D: Data,
182+
D: Data + RawDataClone + Clone,
174183
D::Elem: Num + PartialOrd + Copy + Debug,
175-
S: StrategyND<D>,
184+
S: StrategyND<D> + Clone,
176185
{
177186
fn ndim(&self) -> usize {
178187
self.data.ndim()
@@ -229,7 +238,7 @@ where
229238

230239
impl<D> InterpND<D, Box<dyn StrategyND<D>>>
231240
where
232-
D: Data,
241+
D: Data + RawDataClone + Clone,
233242
D::Elem: Num + PartialOrd + Copy + Debug,
234243
{
235244
/// Update strategy dynamically.

src/n/strategies.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use super::*;
22

33
use itertools::Itertools;
4-
// TODO: any way to remove `RawDataClone`?
5-
use ndarray::RawDataClone;
64

75
pub fn get_index_permutations(shape: &[usize]) -> Vec<Vec<usize>> {
86
if shape.is_empty() {
@@ -17,8 +15,7 @@ pub fn get_index_permutations(shape: &[usize]) -> Vec<Vec<usize>> {
1715

1816
impl<D> StrategyND<D> for Linear
1917
where
20-
// TODO: any way to remove the `RawDataClone` bound?
21-
D: Data + RawDataClone,
18+
D: Data + RawDataClone + Clone,
2219
D::Elem: Num + PartialOrd + Copy + Debug,
2320
{
2421
fn interpolate(
@@ -120,8 +117,7 @@ where
120117

121118
impl<D> StrategyND<D> for Nearest
122119
where
123-
// TODO: any way to remove the `RawDataClone` bound?
124-
D: Data + RawDataClone,
120+
D: Data + RawDataClone + Clone,
125121
D::Elem: Num + PartialOrd + Copy + Debug,
126122
{
127123
fn interpolate(

0 commit comments

Comments
 (0)