Skip to content

Commit a1c4fda

Browse files
authored
Improve support for Float16 (#74)
* Improve support for Float16 * Base `cexpexp(::Float16)` on `Float32` computations * Test only non-broken Julia versions * Update Project.toml
1 parent cbd1bbe commit a1c4fda

File tree

4 files changed

+42
-19
lines changed

4 files changed

+42
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.25"
4+
version = "0.3.26"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/LogExpFunctions.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp,
1010
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
1111
softmax!, logcosh, logabssinh, cloglog, cexpexp
1212

13+
# expm1(::Float16) is not defined in older Julia versions,
14+
# hence for better Float16 support we use an internal function instead
15+
# https://github.com/JuliaLang/julia/pull/40867
16+
if VERSION < v"1.7.0-DEV.1172"
17+
_expm1(x) = expm1(x)
18+
_expm1(x::Float16) = Float16(expm1(Float32(x)))
19+
else
20+
const _expm1 = expm1
21+
end
22+
1323
include("basicfuns.jl")
1424
include("logsumexp.jl")
1525

src/basicfuns.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ function log1mexp(x::Real)
237237
if x < oftype(float(x), IrrationalConstants.loghalf)
238238
return log1p(-exp(x))
239239
else
240-
return log(-expm1(x))
240+
return log(-_expm1(x))
241241
end
242242
end
243243

@@ -246,15 +246,15 @@ $(SIGNATURES)
246246
247247
Return `log(2 - exp(x))` evaluated as `log1p(-expm1(x))`
248248
"""
249-
log2mexp(x::Real) = log1p(-expm1(x))
249+
log2mexp(x::Real) = log1p(-_expm1(x))
250250

251251
"""
252252
$(SIGNATURES)
253253
254254
Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse of
255255
[`log1pexp`](@ref) (aka “softplus”).
256256
"""
257-
logexpm1(x::Real) = x <= 18.0 ? log(expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
257+
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
258258
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)
259259

260260
const softplus = log1pexp
@@ -440,4 +440,4 @@ $(SIGNATURES)
440440
441441
Compute the complementary double exponential, `1 - exp(-exp(x))`.
442442
"""
443-
cexpexp(x) = -expm1(-exp(x))
443+
cexpexp(x) = -_expm1(-exp(x))

test/basicfuns.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,28 @@ end
159159
end
160160

161161
@testset "log1mexp" begin
162-
@test log1mexp(-1.0) log1p(- exp(-1.0))
163-
@test log1mexp(-10.0) log1p(- exp(-10.0))
162+
for T in (Float64, Float32, Float16)
163+
@test @inferred(log1mexp(-T(1))) isa T
164+
@test log1mexp(-T(1)) log1p(- exp(-T(1)))
165+
@test log1mexp(-T(10)) log1p(- exp(-T(10)))
166+
end
164167
end
165168

166169
@testset "log2mexp" begin
167-
@test log2mexp(0.0) 0.0
168-
@test log2mexp(-1.0) log(2.0 - exp(-1.0))
170+
for T in (Float64, Float32, Float16)
171+
@test @inferred(log2mexp(T(0))) isa T
172+
@test iszero(log2mexp(T(0)))
173+
@test log2mexp(-T(1)) log(2 - exp(-T(1)))
174+
end
169175
end
170176

171177
@testset "logexpm1" begin
172-
@test logexpm1(2.0) log(exp(2.0) - 1.0)
173-
@test logexpm1(log1pexp(2.0)) 2.0
174-
@test logexpm1(log1pexp(-2.0)) -2.0
175-
176-
@test logexpm1(2f0) log(exp(2f0) - 1f0)
177-
@test logexpm1(log1pexp(2f0)) 2f0
178-
@test logexpm1(log1pexp(-2f0)) -2f0
178+
for T in (Float64, Float32, Float16)
179+
@test @inferred(logexpm1(T(2))) isa T
180+
@test logexpm1(T(2)) log(exp(T(2)) - 1)
181+
@test logexpm1(log1pexp(T(2))) T(2)
182+
@test logexpm1(log1pexp(-T(2))) -T(2)
183+
end
179184
end
180185

181186
@testset "log1pmx" begin
@@ -433,9 +438,17 @@ end
433438
cloglog_big(x::T) where {T} = T(log(-log(1 - BigFloat(x))))
434439
cexpexp_big(x::T) where {T} = 1 - exp(-exp(BigFloat(x)))
435440

436-
for x in 0.1:0.1:0.9
437-
@test cloglog(x) cloglog_big(x)
438-
@test cexpexp(x) cexpexp_big(x)
441+
for T in (Float64, Float32, Float16)
442+
@test @inferred(cloglog(T(1//2))) isa T
443+
@test @inferred(cexpexp(T(0))) isa T
444+
for x in 0.1:0.1:0.9
445+
@test cloglog(T(x)) cloglog_big(T(x))
446+
# Julia bug for Float32 and Float16 initially introduced in https://github.com/JuliaLang/julia/pull/37440
447+
# and fixed in https://github.com/JuliaLang/julia/pull/50989
448+
if T === Float64 || VERSION < v"1.7.0-DEV.887" || VERSION >= v"1.11.0-DEV.310"
449+
@test cexpexp(T(x)) cexpexp_big(T(x))
450+
end
451+
end
439452
end
440453
for _ in 1:10
441454
randf = rand(Float64)

0 commit comments

Comments
 (0)