Skip to content

Commit d746f22

Browse files
committed
Add Halley and Householder to SimpleNonlinearSolve
1 parent 83f8c3d commit d746f22

File tree

6 files changed

+162
-16
lines changed

6 files changed

+162
-16
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2727
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2828
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2929
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
30+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
3031
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3132

3233
[extensions]
3334
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
3435
SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase"
3536
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
37+
SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff"
3638
SimpleNonlinearSolveTrackerExt = "Tracker"
3739

3840
[compat]
@@ -66,6 +68,7 @@ SciMLBase = "2.58"
6668
Setfield = "1.1.1"
6769
StaticArrays = "1.9"
6870
StaticArraysCore = "1.4.3"
71+
TaylorDiff = "0.3"
6972
Test = "1.10"
7073
TestItemRunner = "1"
7174
Tracker = "0.2.35"
@@ -84,10 +87,11 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
8487
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8588
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8689
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
90+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
8791
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8892
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
8993
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9094
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9195

9296
[targets]
93-
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
97+
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "TaylorDiff", "Test", "TestItemRunner", "Tracker", "Zygote"]
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module SimpleNonlinearSolveTaylorDiffExt
2+
using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleHouseholder, Utils
3+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
4+
AbstractNonlinearSolveAlgorithm
5+
using MaybeInplace: @bb
6+
using FastClosures: @closure
7+
import SciMLBase
8+
import TaylorDiff
9+
10+
SimpleNonlinearSolve.is_extension_loaded(::Val{:TaylorDiff}) = true
11+
12+
const NLBUtils = NonlinearSolveBase.Utils
13+
14+
@inline function __get_higher_order_derivatives(
15+
::SimpleHouseholder{N}, prob, x, fx) where {N}
16+
vN = Val(N)
17+
l = map(one, x)
18+
t = TaylorDiff.make_seed(x, l, vN)
19+
20+
if SciMLBase.isinplace(prob)
21+
bundle = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N})
22+
prob.f(bundle, t, prob.p)
23+
map!(TaylorDiff.value, fx, bundle)
24+
else
25+
bundle = prob.f(t, prob.p)
26+
fx = map(TaylorDiff.value, bundle)
27+
end
28+
invbundle = inv.(bundle)
29+
num = N == 1 ? map(TaylorDiff.value, invbundle) :
30+
TaylorDiff.extract_derivative(invbundle, Val(N - 1))
31+
den = TaylorDiff.extract_derivative(invbundle, vN)
32+
return num, den, fx
33+
end
34+
35+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N},
36+
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
37+
termination_condition = nothing, alias_u0 = false, kwargs...) where {N}
38+
length(prob.u0) == 1 ||
39+
throw(ArgumentError("SimpleHouseholder only supports scalar problems"))
40+
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
41+
fx = NLBUtils.evaluate_f(prob, x)
42+
43+
iszero(fx) &&
44+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
45+
46+
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
47+
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
48+
49+
@bb xo = similar(x)
50+
51+
for i in 1:maxiters
52+
@bb copyto!(xo, x)
53+
num, den, fx = __get_higher_order_derivatives(alg, prob, x, fx)
54+
@bb x .+= N .* num ./ den
55+
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
56+
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
57+
end
58+
59+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
60+
end
61+
62+
function SimpleNonlinearSolve.evaluate_hvvp_internal(hvvp, prob::ImmutableNonlinearProblem, u, a)
63+
if SciMLBase.isinplace(prob)
64+
binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
65+
TaylorDiff.derivative!(hvvp, binary_f, cache.fu, u, a, Val(2))
66+
else
67+
unary_f = Base.Fix2(prob.f, prob.p)
68+
hvvp = TaylorDiff.derivative(unary_f, u, a, Val(2))
69+
end
70+
hvvp
71+
end
72+
73+
end

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ include("utils.jl")
4949
include("broyden.jl")
5050
include("dfsane.jl")
5151
include("halley.jl")
52+
include("householder.jl")
5253
include("klement.jl")
5354
include("lbroyden.jl")
5455
include("raphson.jl")
@@ -128,6 +129,13 @@ end
128129

129130
function solve_adjoint_internal end
130131

132+
function evaluate_hvvp(args...; kws...)
133+
is_extension_loaded(Val(:TaylorDiff)) && return evaluate_hvvp_internal(args...; kws...)
134+
error("Halley's mathod with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
135+
end
136+
137+
function evaluate_hvvp_internal end
138+
131139
@setup_workload begin
132140
for T in (Float64,)
133141
prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
@@ -161,7 +169,7 @@ end
161169
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
162170
export SimpleDFSane
163171
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
164-
export SimpleHalley
172+
export SimpleHalley, SimpleHouseholder
165173

166174
export solve
167175

lib/SimpleNonlinearSolve/src/halley.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
SimpleHalley(autodiff)
3-
SimpleHalley(; autodiff = nothing)
2+
SimpleHalley(autodiff, taylor_mode)
3+
SimpleHalley(; autodiff = nothing, taylor_mode = Val(false))
44
55
A low-overhead implementation of Halley's Method.
66
@@ -15,16 +15,18 @@ A low-overhead implementation of Halley's Method.
1515
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
1616
automatic backend selection). Valid choices include jacobian backends from
1717
`DifferentiationInterface.jl`.
18+
- `taylor_mode`: whether to use Taylor mode automatic differentiation to compute the Hessian-vector-vector product. Defaults to `Val(false)`. If `Val(true)`, you must have `TaylorDiff.jl` loaded.
1819
"""
1920
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
2021
autodiff = nothing
22+
taylor_mode = Val(false)
2123
end
2224

2325
function SciMLBase.__solve(
24-
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
26+
prob::ImmutableNonlinearProblem, alg::SimpleHalley{ad, Val{taylor_mode}}, args...;
2527
abstol = nothing, reltol = nothing, maxiters = 1000,
2628
alias_u0 = false, termination_condition = nothing, kwargs...
27-
)
29+
) where {ad, taylor_mode}
2830
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
2931
fx = NLBUtils.evaluate_f(prob, x)
3032
T = promote_type(eltype(fx), eltype(x))
@@ -50,8 +52,19 @@ function SciMLBase.__solve(
5052
A, Aaᵢ, cᵢ = x, x, x
5153
end
5254

55+
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
56+
NLBUtils.safe_similar(fx) : fx
57+
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
58+
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
59+
5360
for _ in 1:maxiters
54-
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
61+
if taylor_mode
62+
fx = NLBUtils.evaluate_f!!(prob, fx, x)
63+
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
64+
H = nothing
65+
else
66+
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
67+
end
5568

5669
NLBUtils.can_setindex(x) || (A = J)
5770

@@ -67,12 +80,17 @@ function SciMLBase.__solve(
6780
end
6881

6982
aᵢ = J_fact \ NLBUtils.safe_vec(fx)
70-
A_ = NLBUtils.safe_vec(A)
71-
@bb A_ = H × aᵢ
72-
A = NLBUtils.restructure(A, A_)
7383

74-
@bb Aaᵢ = A × aᵢ
75-
@bb A .*= -1
84+
if taylor_mode
85+
Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ))
86+
else
87+
A_ = NLBUtils.safe_vec(A)
88+
@bb A_ = H × aᵢ
89+
A = NLBUtils.restructure(A, A_)
90+
91+
@bb Aaᵢ = A × aᵢ
92+
@bb A .*= -1
93+
end
7694
bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ)
7795

7896
cᵢ_ = NLBUtils.safe_vec(cᵢ)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
SimpleHouseholder{order}()
3+
4+
A low-overhead implementation of Householder's method to arbitrary order.
5+
This method is non-allocating on scalar and static array problems.
6+
7+
!!! warning
8+
9+
Needs `TaylorDiff.jl` to be explicitly loaded before using this functionality.
10+
Internally, this uses TaylorDiff.jl for automatic differentiation.
11+
12+
### Type Parameters
13+
14+
- `order`: the order of the Householder method. `order = 1` is the same as Newton's method, `order = 2` is the same as Halley's method, etc.
15+
"""
16+
struct SimpleHouseholder{order} <: AbstractSimpleNonlinearSolveAlgorithm end

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testsnippet RootfindTestSnippet begin
22
using StaticArrays, Random, LinearAlgebra, ForwardDiff, NonlinearSolveBase, SciMLBase
33
using ADTypes, PolyesterForwardDiff, Enzyme, ReverseDiff
4+
import TaylorDiff
45

56
quadratic_f(u, p) = u .* u .- p
67
quadratic_f!(du, u, p) = (du .= u .* u .- p)
@@ -82,21 +83,47 @@ end
8283
AutoFiniteDiff(),
8384
AutoReverseDiff(),
8485
nothing
85-
)
86+
), taylor_mode in (Val(false), Val(true))
8687
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
8788
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
88-
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff))
89+
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff, taylor_mode))
90+
@test SciMLBase.successful_retcode(sol)
91+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
92+
end
93+
end
94+
95+
@testset for taylor_mode in (Val(false), Val(true))
96+
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
97+
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
98+
99+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
100+
@test all(solve(
101+
probN, alg(; autodiff = AutoForwardDiff(), taylor_mode); termination_condition).u .≈
102+
sqrt(2.0))
103+
end
104+
end
105+
end
106+
end
107+
108+
@testitem "Higher Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
109+
@testset for alg in (
110+
SimpleHouseholder,
111+
)
112+
@testset for order in (1, 2, 3, 4)
113+
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
114+
[1.0], @SVector[1.0], 1.0)
115+
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg{order}())
89116
@test SciMLBase.successful_retcode(sol)
90117
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
91118
end
92119
end
93120

94121
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
95-
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
122+
u0 in (1.0, [1.0], @SVector[1.0])
96123

97124
probN = NonlinearProblem(quadratic_f, u0, 2.0)
98125
@test all(solve(
99-
probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈
126+
probN, alg{2}(); termination_condition).u .≈
100127
sqrt(2.0))
101128
end
102129
end

0 commit comments

Comments
 (0)