Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make ia_solve modular #1348

Merged
merged 9 commits into from
Nov 6, 2024
9 changes: 9 additions & 0 deletions docs/src/manual/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ to `solve_univar`. We can see that essentially, `solve_univar` is the building b
it to `ia_solve`, which attempts solving by attraction and isolation [^2]. This only works when the input is a single expression
and the user wants the answer in terms of a single variable. Say `log(x) - a == 0` gives us `[e^a]`.

```@docs
Symbolics.solve_univar
Symbolics.solve_multivar
Symbolics.ia_solve
Symbolics.ia_conditions!
Symbolics.is_periodic
Symbolics.fundamental_period
```

#### Nice examples

```@example solver
Expand Down
4 changes: 4 additions & 0 deletions src/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ inverse(::typeof(NaNMath.log10)) = inverse(log10)
inverse(::typeof(NaNMath.log1p)) = inverse(log1p)
inverse(::typeof(NaNMath.log2)) = inverse(log2)
left_inverse(::typeof(NaNMath.sqrt)) = left_inverse(sqrt)
# inverses of solve helpers
left_inverse(::typeof(ssqrt)) = left_inverse(sqrt)
left_inverse(::typeof(scbrt)) = left_inverse(cbrt)
left_inverse(::typeof(slog)) = left_inverse(log)

function inverse(f::ComposedFunction)
return inverse(f.inner) ∘ inverse(f.outer)
Expand Down
82 changes: 82 additions & 0 deletions src/solver/ia_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,85 @@ function find_logandexpon(arg, var, oper, poly_index)
!isequal(oper_term, 0) && !isequal(constant_term, 0) && return true
return false
end

"""
ia_conditions!(f, lhs, rhs::Vector{Any}, conditions::Vector{Tuple})

If `f` is a left-invertible function, `lhs` and `rhs[i]` are univariate functions and
`f(lhs) ~ rhs[i]` for all `i in eachindex(rhss)`, push to `conditions` all the relevant
conditions on `lhs` or `rhs[i]`. Each condition is of the form `(sym, op)` where `sym`
is an expression involving `lhs` and/or `rhs[i]` and `op` is a binary relational operator.
The condition `op(sym, 0)` is then required to be true for the equation `f(lhs) ~ rhs[i]`
to be valid.

For example, if `f = log`, `lhs = x` and `rhss = [y, z]` then the condition `x > 0` must
be true. Thus, `(lhs, >)` is pushed to `conditions`. Similarly, if `f = sqrt`, `rhs[i] >= 0`
must be true for all `i`, and so `(y, >=)` and `(z, >=)` will be appended to `conditions`.
"""
function ia_conditions!(args...; kwargs...) end

for fn in [log, log2, log10, NaNMath.log, NaNMath.log2, NaNMath.log10, slog]
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
push!(conditions, (lhs, >))
end
end

for fn in [log1p, NaNMath.log1p]
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
push!(conditions, (lhs - 1, >))
end
end

for fn in [sqrt, NaNMath.sqrt, ssqrt]
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
for r in rhs
push!(conditions, (r, >=))
end
end
end

"""
is_periodic(f)

Return `true` if `f` is a single-input single-output periodic function. Return `false` by
default. If `is_periodic(f) == true`, then `fundamental_period(f)` must also be defined.

See also: [`fundamental_period`](@ref)
"""
is_periodic(f) = false

for fn in [
sin, cos, tan, csc, sec, cot, NaNMath.sin, NaNMath.cos, NaNMath.tan, sind, cosd, tand,
cscd, secd, cotd, cospi
]
@eval is_periodic(::typeof($fn)) = true
end

"""
fundamental_period(f)

Return the fundamental period of periodic function `f`. Must only be called if
`is_periodic(f) == true`.

see also: [`is_periodic`](@ref)
"""
function fundamental_period end

for fn in [sin, cos, csc, sec, NaNMath.sin, NaNMath.cos]
@eval fundamental_period(::typeof($fn)) = 2pi
end

for fn in [sind, cosd, cscd, secd]
@eval fundamental_period(::typeof($fn)) = 360.0
end

fundamental_period(::typeof(cospi)) = 2.0

for fn in [tand, cotd]
@eval fundamental_period(::typeof($fn)) = 180.0
Comment on lines +212 to +218
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason that these are floats? any downsides for making them ints? curious

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning was that due to how BasicSymbolic works, calling fundamental_period will always be dynamic dispatch but if all the return types are Float64, Julia should be able to infer that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks

end

for fn in [tan, cot, NaNMath.tan]
# `1pi isa Float64` whereas `pi isa Irrational{:π}`
@eval fundamental_period(::typeof($fn)) = 1pi
end
126 changes: 61 additions & 65 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Symbolics

function isolate(lhs, var; warns=true, conditions=[])
const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt)

function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, periodic_roots = true)
rhs = Vector{Any}([0])
original_lhs = deepcopy(lhs)
lhs = unwrap(lhs)
Expand Down Expand Up @@ -72,12 +74,21 @@ function isolate(lhs, var; warns=true, conditions=[])
power = args[2]
new_roots = []

for i in eachindex(rhs)
for k in 0:(args[2] - 1)
r = wrap(term(^, rhs[i], (1 // power)))
c = wrap(term(*, 2 * (k), pi)) * im / power
root = r * Base.MathConstants.e^c
push!(new_roots, root)
if complex_roots
for i in eachindex(rhs)
for k in 0:(args[2] - 1)
r = term(^, rhs[i], (1 // power))
c = term(*, 2 * (k), pi) * im / power
root = r * Base.MathConstants.e^c
push!(new_roots, root)
end
end
else
for i in eachindex(rhs)
push!(new_roots, term(^, rhs[i], (1 // power)))
if iseven(power)
push!(new_roots, term(-, new_roots[end]))
end
end
end
rhs = []
Expand All @@ -90,57 +101,23 @@ function isolate(lhs, var; warns=true, conditions=[])
lhs = args[2]
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)
end

elseif oper === (log) || oper === (slog)
lhs = args[1]
rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (log2)
lhs = args[1]
rhs = map(sol -> term(^, 2, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (log10)
elseif has_left_inverse(oper)
lhs = args[1]
rhs = map(sol -> term(^, 10, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (sqrt)
lhs = args[1]
append!(conditions, [(r, >=) for r in rhs])
rhs = map(sol -> term(^, sol, 2), rhs)

elseif oper === (cbrt)
lhs = args[1]
rhs = map(sol -> term(^, sol, 3), rhs)

elseif oper === (sin) || oper === (cos) || oper === (tan)
rev_oper = Dict(sin => asin, cos => acos, tan => atan)
lhs = args[1]
# make this global somehow so the user doesnt need to declare it on his own
new_var = gensym()
new_var = (@variables $new_var)[1]
rhs = map(
sol -> term(rev_oper[oper], sol) +
term(*, Base.MathConstants.pi, new_var),
rhs)
@info string(new_var) * " ϵ" * " Ζ"

elseif oper === (asin)
lhs = args[1]
rhs = map(sol -> term(sin, sol), rhs)

elseif oper === (acos)
lhs = args[1]
rhs = map(sol -> term(cos, sol), rhs)

elseif oper === (atan)
lhs = args[1]
rhs = map(sol -> term(tan, sol), rhs)
elseif oper === (exp)
lhs = args[1]
rhs = map(sol -> term(slog, sol), rhs)
ia_conditions!(oper, lhs, rhs, conditions)
invop = left_inverse(oper)
invop = get(SAFE_ALTERNATIVES, invop, invop)
if is_periodic(oper) && periodic_roots
new_var = gensym()
new_var = (@variables $new_var)[1]
period = fundamental_period(oper)
rhs = map(
sol -> term(invop, sol) +
term(*, period, new_var),
rhs)
@info string(new_var) * " ϵ" * " Ζ"
else
rhs = map(sol -> term(invop, sol), rhs)
end
end

lhs = simplify(lhs)
Expand All @@ -149,7 +126,7 @@ function isolate(lhs, var; warns=true, conditions=[])
return rhs, conditions
end

function attract(lhs, var; warns = true)
function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
if n_func_occ(simplify(lhs), var) <= n_func_occ(lhs, var)
lhs = simplify(lhs)
end
Expand All @@ -164,7 +141,9 @@ function attract(lhs, var; warns = true)
end
lhs = attract_trig(lhs, var)

n_func_occ(lhs, var) == 1 && return isolate(lhs, var, warns = warns, conditions=conditions)
if n_func_occ(lhs, var) == 1
return isolate(lhs, var; warns, conditions, complex_roots, periodic_roots)
end

lhs, sub = turn_to_poly(lhs, var)

Expand All @@ -182,12 +161,12 @@ function attract(lhs, var; warns = true)
new_var = collect(keys(sub))[1]
new_var_val = collect(values(sub))[1]

roots, new_conds = isolate(lhs, new_var, warns = warns)
roots, new_conds = isolate(lhs, new_var; warns = warns, complex_roots, periodic_roots)
append!(conditions, new_conds)
new_roots = []

for root in roots
new_sol, new_conds = isolate(new_var_val - root, var, warns = warns)
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
append!(conditions, new_conds)
push!(new_roots, new_sol)
end
Expand All @@ -197,7 +176,7 @@ function attract(lhs, var; warns = true)
end

"""
ia_solve(lhs, var)
ia_solve(lhs, var; kwargs...)
This function attempts to solve transcendental functions by first checking
the "smart" number of occurrences in the input LHS. By smart here we mean
that polynomials are counted as 1 occurrence. for example `x^2 + 2x` is 1
Expand Down Expand Up @@ -226,6 +205,13 @@ we throw an error to tell the user that this is currently unsolvable by our cove
- lhs: a Num/SymbolicUtils.BasicSymbolic
- var: variable to solve for.

# Keyword arguments
- `warns = true`: Whether to emit warnings for unsolvable expressions.
- `complex_roots = true`: Whether to consider complex roots of `x ^ n ~ y`, where `n` is an integer.
- `periodic_roots = true`: If `true`, isolate `f(x) ~ y` as `x ~ finv(y) + n * period` where
`is_periodic(f) == true`, `finv = left_inverse(f)` and `period = fundamental_period(f)`. `n`
is a new anonymous symbolic variable.

# Examples
```jldoctest
julia> solve(a*x^b + c, x)
Expand Down Expand Up @@ -256,20 +242,30 @@ julia> RootFinding.ia_solve(expr, x)
-2 + π*2var"##230" + asin((1//2)*(-1 + RootFinding.ssqrt(-39)))
-2 + π*2var"##234" + asin((1//2)*(-1 - RootFinding.ssqrt(-39)))
```

All transcendental functions for which `left_inverse` is defined are supported.
To enable `ia_solve` to handle custom transcendental functions, define an inverse or
left inverse. If the function is periodic, `is_periodic` and `fundamental_period` must
be defined. If the function imposes certain conditions on its input or output (for
example, `log` requires that its input be positive) define `ia_conditions!`.

See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref),
[`fundamental_period`](@ref), [`ia_conditions!`](@ref).

# References
[^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070).
"""
function ia_solve(lhs, var; warns = true)
function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
nx = n_func_occ(lhs, var)
sols = []
conditions = []
if nx == 0
warns && @warn("Var not present in given expression")
return []
elseif nx == 1
sols, conditions = isolate(lhs, var, warns = warns)
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
elseif nx > 1
sols, conditions = attract(lhs, var, warns = warns)
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
end

isequal(sols, nothing) && return nothing
Expand Down
29 changes: 27 additions & 2 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ end
#@test isequal(lhs, rhs)

lhs = symbolic_solve(log(a*x)-b,x)[1]
@test isequal(Symbolics.arguments(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))))[1], E)
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)

expr = x + 2
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
Expand All @@ -414,7 +414,7 @@ end
@test isapprox(eval(Symbolics.toexpr(symbolic_solve(expr, x)[1])), sqrt(2), atol=1e-6)

expr = 2^(x+1) + 5^(x+3)
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
lhs = ComplexF64.(eval.(Symbolics.toexpr.(ia_solve(expr, x))))
lhs_solve = eval.(Symbolics.toexpr.(symbolic_solve(expr, x)))
rhs = [(-im*Base.MathConstants.pi - log(2) + 3log(5))/(log(2) - log(5))]
@test lhs[1] ≈ rhs[1]
Expand Down Expand Up @@ -471,6 +471,31 @@ end

@test all(lhs .≈ rhs)
@test all(lhs_solve .≈ rhs)

@testset "Keyword arguments" begin
expr = sec(x ^ 2 + 4x + 4) ^ 3 - 3
roots = ia_solve(expr, x)
@test length(roots) == 6 # 2 quadratic roots * 3 roots from cbrt(3)
@test length(Symbolics.get_variables(roots[1])) == 1
_n = only(Symbolics.get_variables(roots[1]))
vals = substitute.(roots, (Dict(_n => 0),))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)

roots = ia_solve(expr, x; complex_roots = false)
@test length(roots) == 2
# the `n` in `θ + n * 2π`
@test length(Symbolics.get_variables(roots[1])) == 1
_n = only(Symbolics.get_variables(roots[1]))
vals = substitute.(roots, (Dict(_n => 0),))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)

roots = ia_solve(expr, x; complex_roots = false, periodic_roots = false)
@test length(roots) == 2
@test length(Symbolics.get_variables(roots[1])) == 0
@test length(Symbolics.get_variables(roots[2])) == 0
vals = eval.(Symbolics.toexpr.(roots))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)
end
end

@testset "Sqrt case poly" begin
Expand Down
Loading