Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/DataInterpolationsND.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,38 @@ the size of `u` along that dimension must match the length of `t` of the corresp
- `u`: The array to be interpolated.
"""
struct NDInterpolation{
N_in, N_out,
ID <: AbstractInterpolationDimension,
N,
N_in,
N_out,
gType <: AbstractInterpolationCache,
D,
uType <: AbstractArray
}
u::uType
interp_dims::NTuple{N_in, ID}
interp_dims::D
cache::gType
function NDInterpolation(u, interp_dims, cache)
if interp_dims isa AbstractInterpolationDimension
interp_dims = (interp_dims,)
end
N_in = length(interp_dims)
N_out = ndims(u) - N_in
function NDInterpolation(u::AbstractArray{<:Any,N}, interp_dims, cache) where N
interp_dims = _add_trailing_interp_dims(interp_dims, Val{N}())
N_in = _count_interpolating_dims(interp_dims)
N_out = _count_noninterpolating_dims(interp_dims)
@assert N_out≥0 "The number of dimensions of u must be at least the number of interpolation dimensions."
validate_size_u(interp_dims, u)
validate_cache(cache, interp_dims, u)
new{N_in, N_out, eltype(interp_dims), typeof(cache), typeof(u)}(
new{N, N_in, N_out, typeof(cache), typeof(interp_dims), typeof(u)}(
u, interp_dims, cache
)
end
end

# TODO probably not type-stable
_count_interpolating_dims(interp_dims) = count(map(d -> !(d isa NoInterpolationDimension), interp_dims))
_count_noninterpolating_dims(interp_dims) = count(map(d -> d isa NoInterpolationDimension, interp_dims))

_add_trailing_interp_dims(dim::AbstractInterpolationDimension, n) =
_add_trailing_interp_dims((dim,), n)
_add_trailing_interp_dims(dims::Tuple, ::Val{N}) where N =
(dims..., ntuple(_ -> NoInterpolationDimension(), Val{N-length(dims)}())...)

# Constructor with optional global cache
function NDInterpolation(u, interp_dims; cache = EmptyCache())
NDInterpolation(u, interp_dims, cache)
Expand All @@ -70,11 +79,11 @@ function (interp::NDInterpolation)(
end

# In place single input evaluation
function (interp::NDInterpolation{N_in})(
out::Union{Number, AbstractArray{<:Number}},
t::Tuple{Vararg{Number, N_in}};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
function (interp::NDInterpolation{N,N_in,N_out})(
out::Union{Number, AbstractArray{<:Number, N_out}},
t::Tuple{Vararg{Number, N}};
derivative_orders::NTuple{N, <:Integer} = ntuple(_ -> 0, N)
) where {N,N_in,N_out}
validate_derivative_orders(derivative_orders, interp)
idx = get_idx(interp.interp_dims, t)
@assert size(out)==size(interp.u)[(N_in + 1):end] "The size of out must match the size of the last N_out dimensions of u."
Expand Down
7 changes: 7 additions & 0 deletions src/interpolation_dimensions.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
NoInterpolationDimension
A dimension that does not perform interpolation.
"""
struct NoInterpolationDimension <: AbstractInterpolationDimension end

"""
LinearInterpolationDimension(t; t_eval = similar(t, 0))
Expand Down
185 changes: 73 additions & 112 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
@@ -1,137 +1,98 @@
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
A::NDInterpolation{N,N_in,N_out},
ts::Tuple{Vararg{Number}},
idx::NTuple{N, <:Integer},
derivative_orders::NTuple{N, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: LinearInterpolationDimension}
out = make_zero!!(out)
any(>(1), derivative_orders) && return out

tᵢ = ntuple(i -> A.interp_dims[i].t[idx[i]], N_in)
tᵢ₊₁ = ntuple(i -> A.interp_dims[i].t[idx[i] + 1], N_in)

# Size of the (hyper)rectangle `t` is in
t_vol = one(eltype(tᵢ))
for (t₁, t₂) in zip(tᵢ, tᵢ₊₁)
t_vol *= t₂ - t₁
) where {N,N_in,N_out}
(; interp_dims, cache, u) = A
check_derivative_order(interp_dims, derivative_orders) || return out
if isnothing(multi_point_index)
multi_point_index = map(_ -> 1, interp_dims)
end
out = make_zero!!(out)
denom = zero(eltype(ts))
# Setup
space = map(iteration_space, interp_dims)
preparations = map(prepare, interp_dims, derivative_orders, multi_point_index, ts, idx)

# Loop over the corners of the (hyper)rectangle `t` is in
for I in Iterators.product(ntuple(i -> (false, true), N_in)...)
c = eltype(out)(inv(t_vol))
for (t_, right_point, d, t₁, t₂) in zip(t, I, derivative_orders, tᵢ, tᵢ₊₁)
c *= if right_point
iszero(d) ? t_ - t₁ : one(t_)
else
iszero(d) ? t₂ - t_ : -one(t_)
end
end
J = (ntuple(i -> idx[i] + I[i], N_in)..., ..)
if iszero(N_out)
out += c * A.u[J...]
for I in Iterators.product(space...)
scaling = map(scale, interp_dims, preparations, I)
J = map(index, interp_dims, ts, idx, I)
product = if cache isa EmptyCache
prod(scaling)
else
@. out += c * A.u[J...]
product = cache.weights[J...] * prod(scaling)
denom += product
end
end
return out
end

function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: ConstantInterpolationDimension}
if any(>(0), derivative_orders)
return if any(i -> !isempty(searchsorted(A.interp_dims[i].t, t[i])), 1:N_in)
typed_nan(out)
if iszero(N_out)
@assert all(map(j -> j isa Integer, J))
out += product * u[J...]
else
out
out .+= product .* view(u, J...)
end
end
idx = ntuple(
i -> t[i] >= A.interp_dims[i].t[end] ? length(A.interp_dims[i].t) : idx[i], N_in)
if iszero(N_out)
out = A.u[idx...]
else
out .= A.u[idx...]
end
return out
end

# BSpline evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims) = A

out = make_zero!!(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
)

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
B_product = prod(dim_in -> basis_function_vals[dim_in][I[dim_in]], 1:N_in)
cp_index = ntuple(
dim_in -> idx[dim_in] + I[dim_in] - degrees[dim_in] - 1, N_in)
if !(cache isa EmptyCache)
if iszero(N_out)
out += B_product * A.u[cp_index...]
out /= denom
else
out .+= B_product * view(A.u, cp_index..., ..)
out ./= denom
end
end

return out
end

# NURBS evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID, <:NURBSWeights},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims, cache) = A
check_derivative_order(dims::Tuple, derivative_orders::Tuple) =
all(map(check_derivative_order, dims, derivative_orders))
check_derivative_order(::LinearInterpolationDimension, d_o) = d_o <= 1
check_derivative_order(::ConstantInterpolationDimension, d_o) = d_0 <= 0
check_derivative_order(::AbstractInterpolationDimension, d_o) = true
# TODO how to handle this
# if derivative_order > 0
# return if any(i -> !isempty(searchsorted(A.interp_dims[i].t, t[i]))
# typed_nan(out)
# else
# out
# end
# end

out = make_zero!!(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
function prepare(d::LinearInterpolationDimension, derivative_order, multi_point_index, t, i)
t₁ = d.t[i]
t₂ = d.t[i + 1]
t_vol_inv = inv(t₂ - t₁)
return (; t, t₁, t₂, t_vol_inv, derivative_order)
end
prepare(::ConstantInterpolationDimension, derivative_orders, multi_point_index, t, i) = (;)
prepare(::NoInterpolationDimension, derivative_orders, multi_point_index, t, i) = (;)
function prepare(d::BSplineInterpolationDimension, derivative_order, multi_point_index, t, i)
# TODO the dim_in arg isn't really needed, so drop it. Currently just 0
basis_function_values = get_basis_function_values(
d, t, i, derivative_order, multi_point_index
)
return (; basis_function_values)
end

denom = zero(eltype(t))

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
B_product = prod(dim_in -> basis_function_vals[dim_in][I[dim_in]], 1:N_in)
cp_index = ntuple(
dim_in -> idx[dim_in] + I[dim_in] - degrees[dim_in] - 1, N_in)
weight = cache.weights[cp_index...]
product = weight * B_product
denom += product
if iszero(N_out)
out += product * A.u[cp_index...]
else
out .+= product * view(A.u, cp_index..., ..)
end
end
iteration_space(::LinearInterpolationDimension) = (false, true)
iteration_space(::ConstantInterpolationDimension) = 1
iteration_space(::NoInterpolationDimension) = 1
iteration_space(d::BSplineInterpolationDimension) = 1:d.degree + 1

if iszero(N_out)
out /= denom
function scale(::LinearInterpolationDimension, prep::NamedTuple, right_point::Bool)
(; t, t₁, t₂, t_vol_inv, derivative_order) = prep
if right_point
iszero(derivative_order) ? t - t₁ : one(t)
else
out ./= denom
end

return out
iszero(derivative_order) ? t₂ - t : -one(t)
end * t_vol_inv
end
scale(::ConstantInterpolationDimension, prep::NamedTuple, i) = 1
scale(::NoInterpolationDimension, prep::NamedTuple, i) = 1
scale(::BSplineInterpolationDimension, prep::NamedTuple, i) = prep.basis_function_values[i]

index(::LinearInterpolationDimension, t, idx, i) = idx + i
index(d::ConstantInterpolationDimension, t, idx, i) = t >= d.t[end] ? length(d.t) : idx[i]
index(::NoInterpolationDimension, t, idx, i) = Colon()
index(d::BSplineInterpolationDimension, t, idx, i) = idx + i - d.degree - 1
51 changes: 27 additions & 24 deletions src/interpolation_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ out of place.

- `derivative_orders`: The partial derivative order for each interpolation dimension. Defaults to `0` for each.
"""
function eval_grid(interp::NDInterpolation{N_in}; kwargs...) where {N_in}
grid_size = map(itp_dim -> length(itp_dim.t_eval), interp.interp_dims)
out = similar(interp.u, (grid_size..., get_output_size(interp)...))
eval_grid!(out, interp; kwargs...)
function eval_grid(interp::NDInterpolation; kwargs...)
sze = map(interp.interp_dims, size(interp.u)) do d, s
d isa NoInterpolationDimension ? s : length(d.t_eval)
end
# TODO: do we need to promote the type here, e.g. for eltype(u) <: Integer ?
out = similar(interp.u, sze)
return eval_grid!(out, interp; kwargs...)
end

"""
Expand All @@ -87,10 +90,11 @@ function eval_grid!(
interp::NDInterpolation{N_in};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
used_interp_dims = _remove(NoInterpolationDimension, interp.interp_dims...)
validate_derivative_orders(derivative_orders, interp; multi_point = true)
backend = get_backend(out)
@assert all(i -> size(out, i) == length(interp.interp_dims[i].t_eval), N_in) "For the first N_in dimensions of out the length must match the t_eval of the corresponding interpolation dimension."
@assert size(out)[(N_in + 1):end]==get_output_size(interp) "The size of the last N_out dimensions of out must be the same as the output size of the interpolation."
@assert all(i -> size(out, i) == length(used_interp_dims[i].t_eval), N_in) "For the first N_in dimensions of out the length must match the t_eval of the corresponding interpolation dimension."
@assert size(out)[(N_in + 1):end] == get_output_size(interp) "The size of the last N_out dimensions of out must be the same as the output size of the interpolation."
eval_kernel(backend)(
out,
interp,
Expand All @@ -104,29 +108,28 @@ end

@kernel function eval_kernel(
out,
@Const(A),
@Const(A::NDInterpolation{N, N_in, N_out}),
derivative_orders,
eval_grid
)
N_in = length(A.interp_dims)
N_out = ndims(A.u) - N_in

eval_grid,
) where {N, N_in, N_out}
k = @index(Global, NTuple)
used_interp_dims = _remove(NoInterpolationDimension, A.interp_dims...)

if eval_grid
t_eval = ntuple(i -> A.interp_dims[i].t_eval[k[i]], N_in)
idx_eval = ntuple(i -> A.interp_dims[i].idx_eval[k[i]], N_in)
else
t_eval = ntuple(i -> A.interp_dims[i].t_eval[only(k)], N_in)
idx_eval = ntuple(i -> A.interp_dims[i].idx_eval[only(k)], N_in)
end
t_eval = ntuple(i -> used_interp_dims[i].t_eval[k[i]], N_in)
idx_eval = ntuple(i -> used_interp_dims[i].idx_eval[k[i]], N_in)

@show N_out
if iszero(N_out)
out[k...] = _interpolate!(
make_out(A, t_eval), A, t_eval, idx_eval, derivative_orders, k)
dest = make_out(A, t_eval)
@show dest t_eval
out[k...] = _interpolate!(dest, A, t_eval, idx_eval, derivative_orders, k)
else
_interpolate!(
view(out, k..., ..),
A, t_eval, idx_eval, derivative_orders, k)
dest = view(out, k..., ..)
_interpolate!(dest, A, t_eval, idx_eval, derivative_orders, k)
end
end

# Remove objects of type T from splatted args (taken from DimensionalData.jl)
Base.@assume_effects :foldable _remove(::Type{T}, x, xs...) where T = (x, _remove(T, xs...)...)
Base.@assume_effects :foldable _remove(::Type{T}, ::T, xs...) where T = _remove(T, xs...)
Base.@assume_effects :foldable _remove(::Type) = ()
Loading
Loading