Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add discontinuity handling API #1359

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,7 @@ end
export inverse, left_inverse, right_inverse, @register_inverse, has_inverse, has_left_inverse, has_right_inverse
include("inverse.jl")

export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity
include("discontinuities.jl")

end # module
106 changes: 106 additions & 0 deletions src/discontinuities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
rootfunction(f)

Given a function `f` with a discontinuity or discontinuous derivative, return the rootfinding
function of `f`. The rootfinding function `g` takes the same arguments as `f`, and is such
that `f` can be described as a piecewise function based on the sign of `g`, where each piece
is continuous and has a continuous derivative. The pieces are obtained using
`left_continuous_function(f)` and `right_continuous_function(f)`.

More formally,
```julia
f(args...) = if g(args...) < 0
left_continuous_function(f)(args...)
else
right_continuous_function(f)(args...)
end
```

For example, if `f` is `max(x, y)`, the root function is `(x, y) -> x - y` with
`left_continuous_function` as `(x, y) -> y` and `right_continuous_function` as
`(x, y) -> x`.

See also: [`left_continuous_function`](@ref), [`right_continuous_function`](@ref).
"""
function rootfunction end

"""
left_continuous_function(f)

Given a function `f` with a discontinuity or discontinuous derivative, return a function
taking the same arguments as `f` which is continuous and has a continuous derivative
when `rootfinding_function(f)` is negative.

See also: [`rootfunction`](@ref).
"""
function left_continuous_function end

"""
right_continuous_function(f)

Given a function `f` with a discontinuity or discontinuous derivative, return a function
taking the same arguments as `f` which is continuous and has a continuous derivative
when `rootfinding_function(f)` is positive.

See also: [`rootfunction`](@ref).
"""
function right_continuous_function end

"""
@register_discontinuity f(arg1, arg2, ...) root_expr left_expr right_expr

Utility macro to register functions with discontinuities. The function `f` with
arguments `arg1, arg2, ...` has a `rootfunction` of `root_expr`, a
`left_continuous_function` of `left_expr` and `right_continuous_function` of
`right_expr`. `root_expr`, `left_expr` and `right_expr` are all expressions in terms
of `arg1, arg2, ...`.

For example, `max(x, y)` can be registered as `@register_discontinuity max(x, y) x - y y x`.

See also: [`rootfunction`](@ref)
"""
macro register_discontinuity(f, root, left, right)
Meta.isexpr(f, :call) || error("Expected function call as first argument")
args = f.args[2:end]
fn = esc(f.args[1])
rootname = gensym(:root)
rootfn = :(function $rootname($(args...))
$root
end)
leftname = gensym(:left)
leftfn = :(function $leftname($(args...))
$left
end)
rightname = gensym(:right)
rightfn = :(function $rightname($(args...))
$right
end)
return quote
$rootfn
(::$typeof($rootfunction))(::$typeof($fn)) = $rootname
$leftfn
(::$typeof($left_continuous_function))(::$typeof($fn)) = $leftname
$rightfn
(::$typeof($right_continuous_function))(::$typeof($fn)) = $rightname
end
end

# a triangle function which is zero when x is a multiple of period
function _triangle(x, period)
x /= 2period
abs(x + 1 // 4 - floor(x + 3 // 4)) - 1 // 2
end

@register_discontinuity abs(x) x -x x
# just needs a rootfind to hit the discontinuity
@register_discontinuity mod(x, y) _triangle(x, y) mod(x, y) mod(x, y)
@register_discontinuity rem(x, y) _triangle(x, y) rem(x, y) rem(x, y)
@register_discontinuity div(x, y) _triangle(x, y) div(x, y) div(x, y)
@register_discontinuity max(x, y) x - y y x
@register_discontinuity min(x, y) x - y x y
@register_discontinuity NaNMath.max(x, y) x - y y x
@register_discontinuity NaNMath.min(x, y) x - y x y
@register_discontinuity <(x, y) x - y true false
@register_discontinuity <=(x, y) y - x false true
@register_discontinuity >(x, y) y - x true false
@register_discontinuity >=(x, y) x - y false true
22 changes: 12 additions & 10 deletions src/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,26 @@ inverse.
"""
macro register_inverse(f, g, dir::QuoteNode = :(:both))
dir = dir.value
f = esc(f)
g = esc(g)
if dir == :both
quote
(::typeof($inverse))(::typeof($f)) = $g
(::typeof($inverse))(::typeof($g)) = $f
(::typeof($left_inverse))(::typeof($f)) = $(inverse)($f)
(::typeof($right_inverse))(::typeof($f)) = $(inverse)($f)
(::typeof($left_inverse))(::typeof($g)) = $(inverse)($g)
(::typeof($right_inverse))(::typeof($g)) = $(inverse)($g)
(::$typeof($inverse))(::$typeof($f)) = $g
(::$typeof($inverse))(::$typeof($g)) = $f
(::$typeof($left_inverse))(::$typeof($f)) = $(inverse)($f)
(::$typeof($right_inverse))(::$typeof($f)) = $(inverse)($f)
(::$typeof($left_inverse))(::$typeof($g)) = $(inverse)($g)
(::$typeof($right_inverse))(::$typeof($g)) = $(inverse)($g)
end
elseif dir == :left
quote
(::typeof($left_inverse))(::typeof($f)) = $g
(::typeof($right_inverse))(::typeof($g)) = $f
(::$typeof($left_inverse))(::$typeof($f)) = $g
(::$typeof($right_inverse))(::$typeof($g)) = $f
end
elseif dir == :right
quote
(::typeof($right_inverse))(::typeof($f)) = $g
(::typeof($left_inverse))(::typeof($g)) = $f
(::$typeof($right_inverse))(::$typeof($f)) = $g
(::$typeof($left_inverse))(::$typeof($g)) = $f
end
else
throw(ArgumentError("The third argument to `@register_inverse` must be `left` or `right`"))
Expand Down
30 changes: 30 additions & 0 deletions test/discontinuities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Symbolics, NaNMath, Test

function discontinuity_eval(fn, args...)
if rootfunction(fn)(args...) < 0
left_continuous_function(fn)(args...)
else
right_continuous_function(fn)(args...)
end
end

@testset "abs" begin
for x in -1.0:0.001:1.0
@test abs(x) ≈ discontinuity_eval(abs, x)
end
end

@testset "$(nameof(f))" for f in (mod, rem, div)
y = 0.7
for x in -2y:0.001:2y
@test f(x, y) ≈ discontinuity_eval(f, x, y)
end
end

@testset "$(nameof(f))" for f in (min, max, NaNMath.min, NaNMath.max, <, <=, >, >=)
for x in 0.0:0.1:1.0
for y in 0.0:0.1:1.0
@test f(x, y) ≈ discontinuity_eval(f, x, y)
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "RootFinding solver" begin include("solver.jl") end
@safetestset "Function inverses test" begin include("inverse.jl") end
@safetestset "Taylor Series Test" begin include("taylor.jl") end
@safetestset "Discontinuity registration test" begin include("discontinuities.jl") end
end
end

Expand Down
Loading