Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid stack overflows with non-standard float types #79

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Files generated by invoking Julia with --code-coverage
*.jl.cov
*.jl.*.cov

# Files generated by invoking Julia with --track-allocation
*.jl.mem

# System-specific files and directories generated by the BinaryProvider and BinDeps packages
# They contain absolute paths specific to the host computer, and so should not be committed
deps/deps.jl
deps/build.log
deps/downloads/
deps/usr/
deps/src/

# Build artifacts for creating documentation generated by the Documenter package
docs/build/
docs/site/

# File generated by Pkg, the package manager, based on a corresponding Project.toml
# It records a fixed state of all packages used by the project. As such, it should not be
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# Standard editor/IDE customizations
.vscode
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
julia = "1.10"

[extras]
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["DoubleFloats", "Test"]
47 changes: 34 additions & 13 deletions src/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,59 @@ module NaNMath
using OpenLibm_jll
const libm = OpenLibm_jll.libopenlibm


for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:lgamma, :log1p)
@eval begin
($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
($f)(x::Real) = ($f)(float(x))
if $f !== :lgamma
($f)(x::Float16) = Float16(($f)(Float32(x)))
function ($f)(x::Real)
xf = float(x)
x === xf && throw(MethodError($f, (x,)))
return ($f)(xf)
end
if $f !== :lgamma
($f)(x) = (Base.$f)(x)
end
end
end
end

for f in (:sqrt,)
@eval ($f)(x) = (Base.$f)(x)
end

for f in (:max, :min)
@eval ($f)(x, y) = (Base.$f)(x, y)
end
sin(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.sin(x) : T(NaN)
cos(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.cos(x) : T(NaN)
tan(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.tan(x) : T(NaN)
asin(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.asin(x)
acos(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.acos(x)
acosh(x::T) where {T<:AbstractFloat} = x < 1 ? T(NaN) : Base.acosh(x)
atanh(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.atanh(x)
log(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log(x)
log2(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log2(x)
log10(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log10(x)
# lgamma does not have a Base version; the MethodError above will suffice
log1p(x::T) where {T<:AbstractFloat} = x < -1 ? T(NaN) : Base.log1p(x)

# Would be more efficient to remove the domain check in Base.sqrt(),
# but this doesn't seem easy to do.
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
sqrt(x::Real) = sqrt(float(x))
function sqrt(x::Real)
xf = float(x)
x === xf && throw(MethodError(sqrt, (x,)))
return sqrt(xf)
end

# Don't override built-in ^ operator
pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x, y)
pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y)
pow(x::Float16, y::Float16) = Float16(pow(Float32(x), Float32(y)))
# We `promote` first before converting to floating pointing numbers to ensure that
# e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)`
pow(x::Real, y::Real) = pow(promote(x, y)...)
pow(x::T, y::T) where {T<:Real} = pow(float(x), float(y))
pow(x::Number, y::Number) = pow(promote(x, y)...)
function pow(x::T, y::T) where {T<:Number}
yf = float(y)
xf = float(x)
x === xf && y === yf && throw(MethodError(pow, (x,y)))
return pow(xf, yf)
end
pow(x, y) = ^(x, y)

# The following combinations are safe, so we can fall back to ^
Expand Down
62 changes: 60 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,52 @@
using NaNMath
using Test
using DoubleFloats

for T in (Float64, Float32, Float16, BigFloat)
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:log1p) # Note: do :lgamma separately because it can't handle BigFloat
@eval begin
@test NaNMath.$f($T(2//3)) isa $T
@test NaNMath.$f($T(3//2)) isa $T
@test NaNMath.$f($T(-2//3)) isa $T
@test NaNMath.$f($T(-3//2)) isa $T
@test NaNMath.$f($T(Inf)) isa $T
@test NaNMath.$f($T(-Inf)) isa $T
end
end
end
for T in (Float64, Float32, Float16)
@test NaNMath.lgamma(T(2//3)) isa T
@test NaNMath.lgamma(T(3//2)) isa T
@test NaNMath.lgamma(T(-2//3)) isa T
@test NaNMath.lgamma(T(-3//2)) isa T
@test NaNMath.lgamma(T(Inf)) isa T
@test NaNMath.lgamma(T(-Inf)) isa T
end
@test_throws MethodError NaNMath.lgamma(BigFloat(2//3))

@test isnan(NaNMath.log(-10))
@test isnan(NaNMath.log(-10f0))
@test isnan(NaNMath.log(Float16(-10)))
@test isnan(NaNMath.log1p(-100))
@test isnan(NaNMath.log1p(-100f0))
@test isnan(NaNMath.log1p(Float16(-100)))
@test isnan(NaNMath.pow(-1.5,2.3))
@test isnan(NaNMath.pow(-1.5f0,2.3f0))
@test isnan(NaNMath.pow(-1.5,2.3f0))
@test isnan(NaNMath.pow(-1.5f0,2.3))
@test isnan(NaNMath.pow(Float16(-1.5),Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3))
@test isnan(NaNMath.pow(-1.5,Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Float16(2.3)))
@test isnan(NaNMath.pow(-1.5f0,BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),2.3f0))
@test NaNMath.pow(-1,2) isa Float64
@test NaNMath.pow(-1.5f0,2) isa Float32
@test NaNMath.pow(-1.5f0,2//1) isa Float32
@test NaNMath.pow(-1.5f0,2.3f0) isa Float32
Expand All @@ -15,16 +55,34 @@ using Test
@test NaNMath.pow(-1.5,2//1) isa Float64
@test NaNMath.pow(-1.5,2.3f0) isa Float64
@test NaNMath.pow(-1.5,2.3) isa Float64
@test NaNMath.pow(Float16(-1.5),2.3) isa Float64
@test NaNMath.pow(Float16(-1.5),Float16(2.3)) isa Float16
@test NaNMath.pow(-1.5,Float16(2.3)) isa Float64
@test NaNMath.pow(Float16(-1.5),2.3f0) isa Float32
@test NaNMath.pow(-1.5f0,Float16(2.3)) isa Float32
@test NaNMath.pow(-1.5f0,BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),2.3f0) isa BigFloat
@test NaNMath.pow(-1.5f0,Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),2.3f0) isa Double64
@test NaNMath.sqrt(-5) isa Float64
@test NaNMath.pow(-1,2) === 1
@test NaNMath.pow(2,2) === 4
@test NaNMath.pow(1.0, 1.0+im) === 1.0 + 0.0im
@test NaNMath.pow(1.0+im, 1) === 1.0 + 1.0im
@test NaNMath.pow(1.0+im, 1.0) === 1.0 + 1.0im
@test isnan(NaNMath.sqrt(-5))
@test NaNMath.sqrt(5) == Base.sqrt(5)
@test NaNMath.sqrt(-5f0) isa Float32
@test NaNMath.sqrt(5f0) == Base.sqrt(5f0)
@test NaNMath.sqrt(Float16(-5)) isa Float16
@test NaNMath.sqrt(Float16(5)) == Base.sqrt(Float16(5))
@test NaNMath.sqrt(BigFloat(-5)) isa BigFloat
@test NaNMath.sqrt(BigFloat(5)) == Base.sqrt(BigFloat(5))
@test isnan(NaNMath.sqrt(-3.2f0)) && NaNMath.sqrt(-3.2f0) isa Float32
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@inferred NaNMath.sqrt(5)
@inferred NaNMath.sqrt(5.0)
@inferred NaNMath.sqrt(5.0f0)
Expand Down
Loading