Skip to content

Commit 673fe35

Browse files
committed
Use SCT extension
1 parent 538e649 commit 673fe35

File tree

4 files changed

+57
-56
lines changed

4 files changed

+57
-56
lines changed

ext/DataInterpolationsSparseConnectivityTracerExt.jl

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,24 @@ module DataInterpolationsSparseConnectivityTracerExt
33
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
44
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
55
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
6-
using FillArrays: Fill # from FillArrays.jl
6+
using FillArrays: Fill
77
using DataInterpolations:
8-
AbstractInterpolation,
9-
LinearInterpolation,
10-
QuadraticInterpolation,
11-
LagrangeInterpolation,
12-
AkimaInterpolation,
13-
ConstantInterpolation,
14-
QuadraticSpline,
15-
CubicSpline,
16-
BSplineInterpolation,
17-
BSplineApprox,
18-
CubicHermiteSpline,
19-
# PCHIPInterpolation,
20-
QuinticHermiteSpline,
21-
output_size
8+
AbstractInterpolation,
9+
AkimaInterpolation,
10+
BSplineApprox,
11+
BSplineInterpolation,
12+
ConstantInterpolation,
13+
CubicHermiteSpline,
14+
CubicSpline,
15+
LagrangeInterpolation,
16+
LinearInterpolation,
17+
QuadraticInterpolation,
18+
QuadraticSpline,
19+
QuinticHermiteSpline,
20+
SmoothedLinearInterpolation,
21+
output_size
2222

23-
#===========#
2423
# Utilities #
25-
#===========#
2624

2725
# Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number},
2826
# to avoid any cases where the output size is dependent on the input value.
@@ -33,26 +31,26 @@ function _sct_interpolate(
3331
uType::Type{<:AbstractVector{<:Number}},
3432
t::GradientTracer,
3533
is_der_1_zero,
36-
is_der_2_zero,
37-
)
34+
is_der_2_zero
35+
)
3836
return gradient_tracer_1_to_1(t, is_der_1_zero)
3937
end
4038
function _sct_interpolate(
4139
::AbstractInterpolation,
4240
uType::Type{<:AbstractVector{<:Number}},
4341
t::HessianTracer,
4442
is_der_1_zero,
45-
is_der_2_zero,
46-
)
43+
is_der_2_zero
44+
)
4745
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
4846
end
4947
function _sct_interpolate(
5048
interp::AbstractInterpolation,
5149
uType::Type{<:AbstractMatrix{<:Number}},
5250
t::GradientTracer,
5351
is_der_1_zero,
54-
is_der_2_zero,
55-
)
52+
is_der_2_zero
53+
)
5654
t = gradient_tracer_1_to_1(t, is_der_1_zero)
5755
N = only(output_size(interp))
5856
return Fill(t, N)
@@ -62,48 +60,47 @@ function _sct_interpolate(
6260
uType::Type{<:AbstractMatrix{<:Number}},
6361
t::HessianTracer,
6462
is_der_1_zero,
65-
is_der_2_zero,
66-
)
63+
is_der_2_zero
64+
)
6765
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
6866
N = only(output_size(interp))
6967
return Fill(t, N)
70-
end
68+
end #===========#
7169

72-
#===========#
7370
# Overloads #
74-
#===========#
7571

7672
# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
7773
# all interpolations have a non-zero second derivative at some point in the input domain.
7874

7975
for (I, is_der1_zero, is_der2_zero) in (
80-
(:ConstantInterpolation, true, true),
81-
(:LinearInterpolation, false, true),
82-
(:QuadraticInterpolation, false, false),
83-
(:LagrangeInterpolation, false, false),
84-
(:AkimaInterpolation, false, false),
85-
(:QuadraticSpline, false, false),
86-
(:CubicSpline, false, false),
87-
(:BSplineInterpolation, false, false),
88-
(:BSplineApprox, false, false),
89-
(:CubicHermiteSpline, false, false),
90-
(:QuinticHermiteSpline, false, false),
91-
)
76+
(:AkimaInterpolation, false, false),
77+
(:BSplineApprox, false, false),
78+
(:BSplineInterpolation, false, false),
79+
(:ConstantInterpolation, true, true),
80+
(:CubicHermiteSpline, false, false),
81+
(:CubicSpline, false, false),
82+
(:LagrangeInterpolation, false, false),
83+
(:LinearInterpolation, false, true),
84+
(:QuadraticInterpolation, false, false),
85+
(:QuadraticSpline, false, false),
86+
(:QuinticHermiteSpline, false, false),
87+
(:SmoothedLinearInterpolation, false, false)
88+
)
9289
@eval function (interp::$(I){uType})(
9390
t::AbstractTracer
94-
) where {uType <: AbstractArray{<:Number}}
91+
) where {uType <: AbstractArray{<:Number}}
9592
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
9693
end
9794
end
9895

9996
# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
10097
for I in (
101-
:LagrangeInterpolation,
102-
:BSplineInterpolation,
103-
:BSplineApprox,
104-
:CubicHermiteSpline,
105-
:QuinticHermiteSpline,
106-
)
98+
:LagrangeInterpolation,
99+
:BSplineInterpolation,
100+
:BSplineApprox,
101+
:CubicHermiteSpline,
102+
:QuinticHermiteSpline
103+
)
107104
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector}
108105
p = interp(primal(d))
109106
t = interp(tracer(d))

src/interpolation_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ function cumulative_integral(A::AbstractInterpolation{<:Number}, cache_parameter
191191
Base.require_one_based_indexing(A.u)
192192
idxs = cache_parameters ? (1:(length(A.t) - 1)) : (1:0)
193193
return cumsum(_integral(A, idx, t1, t2)
194-
for (idx, t1, t2) in zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
194+
for (idx, t1, t2) in
195+
zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
195196
end
196197

197198
function get_parameters(A::LinearInterpolation, idx)

test/derivative_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ function test_derivatives(method; args = [], kwargs = [], name::String)
3535

3636
# Interpolation transition points
3737
for _t in t[2:(end - 1)]
38-
if func isa Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
38+
if func isa
39+
Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
3940
# TODO fix interpolations
4041
continue
4142
else

test/sparseconnectivitytracer_tests.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,22 @@ end
232232

233233
@testset "1D Interpolations" begin
234234
@testset "$(testname(t))" for t in (
235+
236+
InterpolationTest(AkimaInterpolation(u, t)),
237+
InterpolationTest(BSplineApprox(u, t, 3, 4, :ArcLen, :Average)),
238+
InterpolationTest(BSplineInterpolation(u, t, 3, :ArcLen, :Average)),
235239
InterpolationTest(
236240
ConstantInterpolation(u, t); is_der1_zero = true, is_der2_zero = true
237241
),
242+
InterpolationTest(CubicHermiteSpline(du, u, t)),
243+
InterpolationTest(CubicSpline(u, t)),
244+
InterpolationTest(LagrangeInterpolation(u, t)),
238245
InterpolationTest(LinearInterpolation(u, t); is_der2_zero = true),
246+
InterpolationTest(PCHIPInterpolation(u, t)),
239247
InterpolationTest(QuadraticInterpolation(u, t)),
240-
InterpolationTest(LagrangeInterpolation(u, t)),
241-
InterpolationTest(AkimaInterpolation(u, t)),
242248
InterpolationTest(QuadraticSpline(u, t)),
243-
InterpolationTest(CubicSpline(u, t)),
244-
InterpolationTest(BSplineInterpolation(u, t, 3, :ArcLen, :Average)),
245-
InterpolationTest(BSplineApprox(u, t, 3, 4, :ArcLen, :Average)),
246-
InterpolationTest(PCHIPInterpolation(u, t)),
247-
InterpolationTest(CubicHermiteSpline(du, u, t)),
248249
InterpolationTest(QuinticHermiteSpline(ddu, du, u, t)),
250+
InterpolationTest(SmoothedLinearInterpolation(u, t))
249251
)
250252
test_jacobian(t)
251253
test_hessian(t)

0 commit comments

Comments
 (0)