Skip to content

Commit 4e8baef

Browse files
Merge pull request #18 from JuliaComputing/scalarfalsi
Scalar Falsi
2 parents 9cdbebb + 090deb4 commit 4e8baef

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ ForwardDiff = "0.10.3"
1818
RecursiveArrayTools = "2"
1919
Reexport = "0.2"
2020
Setfield = "0.7"
21-
StaticArrays = "1.0"
21+
StaticArrays = "0.12,1.0"
2222
UnPack = "1.0"
2323
julia = "1"
2424

src/scalar.jl

+59-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,
5353
end
5454

5555
# avoid ambiguities
56-
for Alg in [Bisection, Falsi]
56+
for Alg in [Bisection]
5757
@eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
5858
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5959
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
@@ -110,3 +110,61 @@ function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kw
110110

111111
return BracketingSolution(left, right, MAXITERS_EXCEED)
112112
end
113+
114+
function solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
115+
f = Base.Fix2(prob.f, prob.p)
116+
left, right = prob.u0
117+
fl, fr = f(left), f(right)
118+
119+
if iszero(fl)
120+
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
121+
end
122+
123+
i = 1
124+
if !iszero(fr)
125+
while i < maxiters
126+
if nextfloat_tdir(left, prob.u0...) == right
127+
return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
128+
end
129+
mid = (fr * left - fl * right) / (fr - fl)
130+
for i in 1:10
131+
mid = max(left, prevfloat_tdir(mid, prob.u0...))
132+
end
133+
if mid == right || mid == left
134+
break
135+
end
136+
fm = f(mid)
137+
if iszero(fm)
138+
right = mid
139+
break
140+
end
141+
if sign(fl) == sign(fm)
142+
fl = fm
143+
left = mid
144+
else
145+
fr = fm
146+
right = mid
147+
end
148+
i += 1
149+
end
150+
end
151+
152+
while i < maxiters
153+
mid = (left + right) / 2
154+
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
155+
fm = f(mid)
156+
if iszero(fm)
157+
right = mid
158+
fr = fm
159+
elseif sign(fm) == sign(fl)
160+
left = mid
161+
fl = fm
162+
else
163+
right = mid
164+
fr = fm
165+
end
166+
i += 1
167+
end
168+
169+
return BracketingSolution(left, right, MAXITERS_EXCEED)
170+
end

test/runtests.jl

+14
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ end
5656
# Scalar
5757
f, u0 = (u, p) -> u * u - p, 1.0
5858

59+
# NewtonRaphson
5960
g = function (p)
6061
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
6162
sol = solve(probN, NewtonRaphson())
@@ -69,6 +70,19 @@ for p in 1.1:0.1:100.0
6970
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
7071
end
7172

73+
u0 = (1.0, 20.0)
74+
# Falsi
75+
g = function (p)
76+
probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
77+
sol = solve(probN, Falsi())
78+
return sol.left
79+
end
80+
81+
for p in 1.1:0.1:100.0
82+
@test g(p) sqrt(p)
83+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
84+
end
85+
7286
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
7387
t = (p) -> [sqrt(p[2] / p[1])]
7488
p = [0.9, 50.0]

0 commit comments

Comments
 (0)