Skip to content

Commit 29ef2ff

Browse files
committed
revert to unfused broadcast
1 parent 689d16b commit 29ef2ff

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/layers/stateless.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,11 @@ function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims)
4747
return (μ, σ), _mean_std_pullback
4848
end
4949

50-
_zscore(x, μ, σ, ϵ) = (x - μ) /+ ϵ)
51-
5250
# We don't define a rrule for the whole function because we want
53-
# AD to figure out the _zscore broadcast for us.
51+
# AD to figure out the broadcast for us.
5452
function _normalize(x::AbstractArray, dims, ϵ)
5553
μ, σ = _mean_std(x, dims)
56-
return _zscore.(x, μ, σ, ϵ)
54+
return @. (x - μ) /+ ϵ)
5755
end
5856

5957
"""

0 commit comments

Comments
 (0)