Skip to content

Commit 538e649

Browse files
committed
Merge remote-tracking branch 'upstream/master' into add_smoothed_linear_interpolation
2 parents 806945e + c181596 commit 538e649

File tree

6 files changed

+423
-16
lines changed

6 files changed

+423
-16
lines changed

Project.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DataInterpolations"
22
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
3-
version = "8.1.0"
3+
version = "8.3.1"
44

55
[deps]
66
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
@@ -13,25 +13,29 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1313

1414
[weakdeps]
1515
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1617
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1718
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
19+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1820
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1921
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2022

2123
[extensions]
2224
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
2325
DataInterpolationsOptimExt = "Optim"
2426
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
27+
DataInterpolationsSparseConnectivityTracerExt = ["SparseConnectivityTracer", "FillArrays"]
2528
DataInterpolationsSymbolicsExt = "Symbolics"
2629

2730
[compat]
2831
Aqua = "0.8"
2932
BenchmarkTools = "1"
3033
ChainRulesCore = "1.24"
3134
EnumX = "1.0.4"
35+
FillArrays = "1.13.0"
3236
FindFirstFunctions = "1.3"
3337
FiniteDifferences = "0.12.31"
34-
ForwardDiff = "0.10.36"
38+
ForwardDiff = "0.10.36, 1"
3539
LinearAlgebra = "1.10"
3640
Optim = "1.6"
3741
PrettyTables = "2"
@@ -40,6 +44,7 @@ RecipesBase = "1.3"
4044
Reexport = "1"
4145
RegularizationTools = "0.6"
4246
SafeTestsets = "0.1"
47+
SparseConnectivityTracer = "1"
4348
StableRNGs = "1"
4449
Symbolics = "5.29, 6"
4550
Test = "1.10"
@@ -57,11 +62,12 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5762
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
5863
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
5964
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
65+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
6066
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
6167
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
6268
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6369
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6470
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6571

6672
[targets]
67-
test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"]
73+
test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote", "SparseConnectivityTracer"]
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
module DataInterpolationsSparseConnectivityTracerExt
2+
3+
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
4+
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
5+
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
6+
using FillArrays: Fill # from FillArrays.jl
7+
using DataInterpolations:
8+
AbstractInterpolation,
9+
LinearInterpolation,
10+
QuadraticInterpolation,
11+
LagrangeInterpolation,
12+
AkimaInterpolation,
13+
ConstantInterpolation,
14+
QuadraticSpline,
15+
CubicSpline,
16+
BSplineInterpolation,
17+
BSplineApprox,
18+
CubicHermiteSpline,
19+
# PCHIPInterpolation,
20+
QuinticHermiteSpline,
21+
output_size
22+
23+
#===========#
24+
# Utilities #
25+
#===========#
26+
27+
# Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number},
28+
# to avoid any cases where the output size is dependent on the input value.
29+
# https://github.com/adrhill/SparseConnectivityTracer.jl/pull/234#discussion_r2031038566
30+
31+
function _sct_interpolate(
32+
::AbstractInterpolation,
33+
uType::Type{<:AbstractVector{<:Number}},
34+
t::GradientTracer,
35+
is_der_1_zero,
36+
is_der_2_zero,
37+
)
38+
return gradient_tracer_1_to_1(t, is_der_1_zero)
39+
end
40+
function _sct_interpolate(
41+
::AbstractInterpolation,
42+
uType::Type{<:AbstractVector{<:Number}},
43+
t::HessianTracer,
44+
is_der_1_zero,
45+
is_der_2_zero,
46+
)
47+
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
48+
end
49+
function _sct_interpolate(
50+
interp::AbstractInterpolation,
51+
uType::Type{<:AbstractMatrix{<:Number}},
52+
t::GradientTracer,
53+
is_der_1_zero,
54+
is_der_2_zero,
55+
)
56+
t = gradient_tracer_1_to_1(t, is_der_1_zero)
57+
N = only(output_size(interp))
58+
return Fill(t, N)
59+
end
60+
function _sct_interpolate(
61+
interp::AbstractInterpolation,
62+
uType::Type{<:AbstractMatrix{<:Number}},
63+
t::HessianTracer,
64+
is_der_1_zero,
65+
is_der_2_zero,
66+
)
67+
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
68+
N = only(output_size(interp))
69+
return Fill(t, N)
70+
end
71+
72+
#===========#
73+
# Overloads #
74+
#===========#
75+
76+
# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
77+
# all interpolations have a non-zero second derivative at some point in the input domain.
78+
79+
for (I, is_der1_zero, is_der2_zero) in (
80+
(:ConstantInterpolation, true, true),
81+
(:LinearInterpolation, false, true),
82+
(:QuadraticInterpolation, false, false),
83+
(:LagrangeInterpolation, false, false),
84+
(:AkimaInterpolation, false, false),
85+
(:QuadraticSpline, false, false),
86+
(:CubicSpline, false, false),
87+
(:BSplineInterpolation, false, false),
88+
(:BSplineApprox, false, false),
89+
(:CubicHermiteSpline, false, false),
90+
(:QuinticHermiteSpline, false, false),
91+
)
92+
@eval function (interp::$(I){uType})(
93+
t::AbstractTracer
94+
) where {uType <: AbstractArray{<:Number}}
95+
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
96+
end
97+
end
98+
99+
# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
100+
for I in (
101+
:LagrangeInterpolation,
102+
:BSplineInterpolation,
103+
:BSplineApprox,
104+
:CubicHermiteSpline,
105+
:QuinticHermiteSpline,
106+
)
107+
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector}
108+
p = interp(primal(d))
109+
t = interp(tracer(d))
110+
return Dual(p, t)
111+
end
112+
113+
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractMatrix}
114+
p = interp(primal(d))
115+
t = interp(tracer(d))
116+
return Dual.(p, t)
117+
end
118+
end
119+
120+
end

src/derivatives.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ function derivative(A, t, order = 1)
66
_extrapolate_derivative_right(A, t, order)
77
else
88
iguess = A.iguesser
9-
(order == 1) ? _derivative(A, t, iguess) :
10-
ForwardDiff.derivative(t -> begin
11-
_derivative(A, t, iguess)
12-
end, t)
9+
if order == 1
10+
return _derivative(A, t, iguess)
11+
end
12+
return ForwardDiff.derivative(t -> begin
13+
-_derivative(A, -t, iguess)
14+
end, -t) # take derivative backwards in t to make it a left rather than right derivative
1315
end
1416
end
1517

@@ -333,9 +335,8 @@ function _derivative(
333335
ducum = (A.c[ax_u..., 2] - A.c[ax_u..., 1]) / (A.k[A.d + 2])
334336
else
335337
for i in 1:(A.h - 1)
336-
ducum = ducum +
337-
sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
338-
(A.k[i + A.d + 1] - A.k[i + 1])
338+
ducum += sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
339+
(A.k[i + A.d + 1] - A.k[i + 1])
339340
end
340341
end
341342
ducum * A.d * scale

test/derivative_tests.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Symbolics
66
using StableRNGs
77
using RegularizationTools
88
using Optim
9-
using ForwardDiff
9+
import ForwardDiff
1010
using LinearAlgebra
1111

1212
function test_derivatives(method; args = [], kwargs = [], name::String)
@@ -35,11 +35,8 @@ function test_derivatives(method; args = [], kwargs = [], name::String)
3535

3636
# Interpolation transition points
3737
for _t in t[2:(end - 1)]
38-
if func isa Union{BSplineInterpolation, BSplineApprox,
39-
CubicHermiteSpline}
40-
fdiff = forward_fdm(5, 1; geom = true)(func, _t)
41-
fdiff2 = forward_fdm(5, 1; geom = true)(t -> derivative(func, t), _t)
42-
elseif func isa SmoothedConstantInterpolation
38+
if func isa Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
39+
# TODO fix interpolations
4340
continue
4441
else
4542
fdiff = backward_fdm(5, 1; geom = true)(func, _t)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SafeTestsets
88
@safetestset "Integral Tests" include("integral_tests.jl")
99
@safetestset "Integral Inverse Tests" include("integral_inverse_tests.jl")
1010
@safetestset "Extrapolation Tests" include("extrapolation_tests.jl")
11+
@safetestset "SparseConnectivityTracer Tests" include("sparseconnectivitytracer_tests.jl")
1112
@safetestset "Online Tests" include("online_tests.jl")
1213
@safetestset "Regularization Smoothing Tests" include("regularization.jl")
1314
@safetestset "Show methods Tests" include("show.jl")

0 commit comments

Comments
 (0)