Skip to content

Commit 172369a

Browse files
committed
feat: support logsoftmax
1 parent 6aab7f7 commit 172369a

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
3232
return out ./= tmp
3333
end
3434

35+
function NNlib.logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims=1) where {T}
36+
max_ = NNlib.fast_maximum(x; dims)
37+
# if all(isfinite, max_)
38+
@fastmath out .= x .- max_
39+
# else
40+
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
41+
# @. out = ifelse(
42+
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
43+
# )
44+
# end
45+
@fastmath log_ = log.(sum(exp, out; dims))
46+
return out .-= log_
47+
end
48+
3549
function NNlib.conv(
3650
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
3751
) where {T,N}

0 commit comments

Comments
 (0)