Skip to content

Commit 74bd04b

Browse files
authored
Fix #2086 re @autosize (#2087)
* fix 2086 * Embedding, but not yet
1 parent 4c38c8a commit 74bd04b

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/outputsize.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,13 @@ function _makelazy(ex::Expr)
248248
n == 0 && return ex
249249
n == 1 && error("@autosize doesn't expect an underscore here: $ex")
250250
n == 2 && return :($LazyLayer($(string(ex)), $(_makefun(ex)), nothing))
251-
n > 2 && return Expr(ex.head, ex.args[1], map(_makelazy, ex.args[2:end])...)
251+
n > 2 && return Expr(ex.head, map(_makelazy, ex.args)...)
252252
end
253253
_makelazy(x) = x
254254

255255
function _underscoredepth(ex::Expr)
256256
# Meta.isexpr(ex, :tuple) && :_ in ex.args && return 10
257-
ex.head in (:call, :kw, :(->), :block) || return 0
257+
ex.head in (:call, :kw, :(->), :block, :parameters) || return 0
258258
ex.args[1] === :(=>) && ex.args[2] === :_ && return 1
259259
m = maximum(_underscoredepth, ex.args)
260260
m == 0 ? 0 : m+1
@@ -279,6 +279,7 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
279279
"""
280280
autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
281281
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
282+
autosizefor(::Type{<:Embedding}, x::AbstractArray) = size(x, 1)
282283
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)
283284

284285
_replaceunderscore(e, s) = e === :_ ? s : e

test/outputsize.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,20 @@ end
174174

175175
m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
176176
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
177-
177+
178+
@test_broken begin # outputsize fails on Embedding
179+
m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last
180+
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
181+
end
182+
178183
m = @autosize (9,) Dense(_ => div(_,2))
179184
@test randn(9) |> m |> size == (4,)
180185

181186
m = @autosize (3,) Chain(one = Dense(_ => 4), two = softmax) # needs kw
182187
@test randn(3) |> m |> size == (4,)
188+
189+
m = @autosize (3,) Chain(; one = Dense(_ => 4), two = softmax) # needs parameters
190+
@test randn(3) |> m |> size == (4,)
183191

184192
m = @autosize (3, 45) Maxout(() -> Dense(_ => 6, tanh), 2) # needs ->, block
185193
@test randn(3, 45) |> m |> size == (6, 45)
@@ -222,6 +230,10 @@ end
222230
Dense(_ => 10),
223231
)
224232
@test randn(Float32, img..., 1, 32) |> m |> size == (10, 32)
233+
234+
# https://github.com/FluxML/Flux.jl/issues/2086
235+
m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false))
236+
@test randn(Float32, 3, 32) |> m |> size == (2, 32)
225237
end
226238

227239
@testset "LazyLayer" begin

0 commit comments

Comments
 (0)