Skip to content

Commit 5f8d644

Browse files
committed
Fix API issue
1 parent d746f22 commit 5f8d644

File tree

4 files changed

+26
-24
lines changed

4 files changed

+26
-24
lines changed

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHousehold
5959
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
6060
end
6161

62-
function SimpleNonlinearSolve.evaluate_hvvp_internal(hvvp, prob::ImmutableNonlinearProblem, u, a)
62+
function SimpleNonlinearSolve.evaluate_hvvp_internal(
63+
hvvp, prob::ImmutableNonlinearProblem, u, a)
6364
if SciMLBase.isinplace(prob)
6465
binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
6566
TaylorDiff.derivative!(hvvp, binary_f, cache.fu, u, a, Val(2))

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function solve_adjoint_internal end
131131

132132
function evaluate_hvvp(args...; kws...)
133133
is_extension_loaded(Val(:TaylorDiff)) && return evaluate_hvvp_internal(args...; kws...)
134-
error("Halley's mathod with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
134+
error("Halley's method with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
135135
end
136136

137137
function evaluate_hvvp_internal end

lib/SimpleNonlinearSolve/src/halley.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
SimpleHalley(autodiff, taylor_mode)
3-
SimpleHalley(; autodiff = nothing, taylor_mode = Val(false))
2+
SimpleHalley(autodiff)
3+
SimpleHalley(; autodiff = nothing)
44
55
A low-overhead implementation of Halley's Method.
66
@@ -15,18 +15,17 @@ A low-overhead implementation of Halley's Method.
1515
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
1616
automatic backend selection). Valid choices include jacobian backends from
1717
`DifferentiationInterface.jl`.
18-
- `taylor_mode`: whether to use Taylor mode automatic differentiation to compute the Hessian-vector-vector product. Defaults to `Val(false)`. If `Val(true)`, you must have `TaylorDiff.jl` loaded.
18+
In addition, `AutoTaylorDiff` can be used to enable Taylor mode for computing the Hessian-vector-vector product more efficiently; in this case, the Jacobian would still be calculated using the default backend. You need to have `TaylorDiff.jl` loaded to use this option.
1919
"""
2020
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
2121
autodiff = nothing
22-
taylor_mode = Val(false)
2322
end
2423

2524
function SciMLBase.__solve(
26-
prob::ImmutableNonlinearProblem, alg::SimpleHalley{ad, Val{taylor_mode}}, args...;
25+
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
2726
abstol = nothing, reltol = nothing, maxiters = 1000,
2827
alias_u0 = false, termination_condition = nothing, kwargs...
29-
) where {ad, taylor_mode}
28+
)
3029
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
3130
fx = NLBUtils.evaluate_f(prob, x)
3231
T = promote_type(eltype(fx), eltype(x))
@@ -40,6 +39,7 @@ function SciMLBase.__solve(
4039

4140
# The way we write the 2nd order derivatives, we know Enzyme won't work there
4241
autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
42+
jac_autodiff = autodiff === AutoTaylorDiff() ? AutoForwardDiff() : autodiff
4343
@set! alg.autodiff = autodiff
4444

4545
@bb xo = copy(x)
@@ -54,16 +54,16 @@ function SciMLBase.__solve(
5454

5555
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
5656
NLBUtils.safe_similar(fx) : fx
57-
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
58-
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
57+
jac_cache = Utils.prepare_jacobian(prob, jac_autodiff, fx_cache, x)
58+
J = Utils.compute_jacobian!!(nothing, prob, jac_autodiff, fx_cache, x, jac_cache)
5959

6060
for _ in 1:maxiters
61-
if taylor_mode
61+
if autodiff isa AutoTaylorDiff
6262
fx = NLBUtils.evaluate_f!!(prob, fx, x)
63-
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
63+
J = Utils.compute_jacobian!!(J, prob, jac_autodiff, fx_cache, x, jac_cache)
6464
H = nothing
6565
else
66-
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
66+
fx, J, H = Utils.compute_jacobian_and_hessian(jac_autodiff, prob, fx, x)
6767
end
6868

6969
NLBUtils.can_setindex(x) || (A = J)
@@ -81,7 +81,7 @@ function SciMLBase.__solve(
8181

8282
aᵢ = J_fact \ NLBUtils.safe_vec(fx)
8383

84-
if taylor_mode
84+
if autodiff isa AutoTaylorDiff
8585
Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ))
8686
else
8787
A_ = NLBUtils.safe_vec(A)

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,25 +82,26 @@ end
8282
AutoForwardDiff(),
8383
AutoFiniteDiff(),
8484
AutoReverseDiff(),
85+
AutoTaylorDiff(),
8586
nothing
86-
), taylor_mode in (Val(false), Val(true))
87+
)
8788
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
8889
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
89-
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff, taylor_mode))
90+
sol = run_nlsolve_oop(
91+
quadratic_f, u0; solver = alg(; autodiff))
9092
@test SciMLBase.successful_retcode(sol)
9193
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
9294
end
9395
end
9496

95-
@testset for taylor_mode in (Val(false), Val(true))
96-
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
97-
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
97+
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
98+
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
9899

99-
probN = NonlinearProblem(quadratic_f, u0, 2.0)
100-
@test all(solve(
101-
probN, alg(; autodiff = AutoForwardDiff(), taylor_mode); termination_condition).u .≈
102-
sqrt(2.0))
103-
end
100+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
101+
@test all(solve(
102+
probN, alg(; autodiff = AutoTaylorDiff());
103+
termination_condition).u .≈
104+
sqrt(2.0))
104105
end
105106
end
106107
end

0 commit comments

Comments
 (0)