
Commit 9b21e2c

Simplify trainable, functor and Parallel (#1862)

* simple functor Chain
* simplify Maxout
* fix show as a result
* trainable always a NamedTuple
* Parallel: delete trainable, call combiner once
* fixup
* fix tests for Flux.modules

1 parent 841afe7

9 files changed: +138 additions, −94 deletions

src/deprecations.jl

Lines changed: 0 additions & 1 deletion

@@ -36,4 +36,3 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
 
 
 # v0.13 deprecations
-@deprecate Maxout(layers::Tuple) Maxout(layers...)
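Aside (my inference, not stated in the diff): the deprecation can go because the refactor in src/layers/basic.jl below makes the tuple of layers a plain struct field, so the default constructor already accepts a tuple. A minimal sketch, assuming this commit:

using Flux

Maxout((Dense(2, 3), Dense(2, 3)))  # hits the new default constructor directly, no deprecation path
Maxout(Dense(2, 3), Dense(2, 3))    # the splatting convenience constructor still works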

src/layers/basic.jl

Lines changed: 34 additions & 33 deletions

@@ -28,29 +28,30 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
 """
-struct Chain{T}
+struct Chain{T<:Union{Tuple, NamedTuple}}
   layers::T
-  Chain(xs...) = new{typeof(xs)}(xs)
-  function Chain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+end
+
+Chain(xs...) = Chain(xs)
+function Chain(; kw...)
+  :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
+  isempty(kw) && return Chain(())
+  Chain(values(kw))
 end
 
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys
 
-functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
+@functor Chain
 
 applychain(::Tuple{}, x) = x
 applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
 
 (c::Chain)(x) = applychain(Tuple(c.layers), x)
 
-Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
-Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
-  Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
+Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
+Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
+  Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
 
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")

@@ -246,29 +247,23 @@ julia> Flux.outputsize(m3, (5, 11))
 (7, 11)
 ```
 """
-struct Maxout{FS<:Tuple}
-  over::FS
-  Maxout(layers...) = new{typeof(layers)}(layers)
-end
-
-function Maxout(f::Function, n_alts::Integer)
-  over = Tuple(f() for _ in 1:n_alts)
-  return Maxout(over...)
+struct Maxout{T<:Tuple}
+  layers::T
 end
+Maxout(layers...) = Maxout(layers)
+Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
 @functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
   # even with Zygote. See #698 and #1794
-  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
 end
 
-trainable(mo::Maxout) = mo.over
-
 function Base.show(io::IO, mo::Maxout)
   print(io, "Maxout(")
-  _show_layers(io, mo.over)
+  _show_layers(io, mo.layers)
   print(io, ")")
 end

@@ -415,8 +410,8 @@ end
 Create a `Parallel` layer that passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
-Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
-If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.

@@ -451,7 +446,7 @@ julia> model2[:β] == model2[2]
 true
 ```
 """
-struct Parallel{F, T}
+struct Parallel{F, T<:Union{Tuple, NamedTuple}}
   connection::F
   layers::T
 end

@@ -461,25 +456,31 @@ function Parallel(connection; kw...)
   layers = NamedTuple(kw)
   if :layers in Base.keys(layers) || :connection in Base.keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
-  elseif isempty(layers)
-    Parallel(connection, ())
   end
+  isempty(layers) && return Parallel(connection, ())
   Parallel(connection, layers)
 end
 
 @functor Parallel
 
-(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
-(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
+function (m::Parallel)(xs...)
+  nl = length(m.layers)
+  nx = length(xs)
+  if nl != nx
+    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+  end
+  m.connection(map(|>, xs, Tuple(m.layers))...)
+end
 
 Base.getindex(m::Parallel, i) = m.layers[i]
-Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
+Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
+Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
+  Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
 
 Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
 
-trainable(m::Parallel) = (m.connection, m.layers...)
-
 function Base.show(io::IO, m::Parallel)
   print(io, "Parallel(", m.connection, ", ")
   _show_layers(io, m.layers)
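For orientation (a sketch, not part of the diff), this is how the reworked `Parallel` is expected to behave at this commit; the layer sizes are invented for illustration:

using Flux

m = Parallel(vcat, α = Dense(2, 3), β = Dense(2, 3))
x = rand(Float32, 2)

m(x) == vcat(m[:α](x), m[:β](x))     # one input goes to every branch; connection is called exactly once
m(x, x) == vcat(m[:α](x), m[:β](x))  # n inputs are matched one-per-branch
m[1:1] isa Parallel                  # indexing with a range now keeps the layer names
# m(x, x, x) would throw an ArgumentError: 2 sub-layers cannot take 3 inputs

The new tests in test/layers/basic.jl below exercise exactly these cases.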

src/layers/normalise.jl

Lines changed: 5 additions & 7 deletions

@@ -82,8 +82,7 @@ function Dropout(p; dims=:, rng = rng_from_array())
 end
 
 @functor Dropout
-
-trainable(a::Dropout) = ()
+trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
   _isactive(a) || return x

@@ -122,8 +121,7 @@ AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
 AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
 
 @functor AlphaDropout
-
-trainable(a::AlphaDropout) = ()
+trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x

@@ -301,7 +299,7 @@ function BatchNorm(chs::Int, λ=identity;
 end
 
 @functor BatchNorm
-trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
+trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
 function (BN::BatchNorm)(x)
   @assert size(x, ndims(x)-1) == BN.chs

@@ -377,7 +375,7 @@ function InstanceNorm(chs::Int, λ=identity;
 end
 
 @functor InstanceNorm
-trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
+trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
 function (l::InstanceNorm)(x)
   @assert ndims(x) > 2

@@ -439,7 +437,7 @@ mutable struct GroupNorm{F,V,N,W}
 end
 
 @functor GroupNorm
-trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
+trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
   initβ=zeros32, initγ=ones32,
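The user-visible effect, as I read it (a sketch, not taken from the diff): `trainable` now always returns a NamedTuple, so trainable parameters keep their field names and the empty case is `(;)` rather than `()`:

using Flux

Flux.trainable(BatchNorm(4))  # expected: (β = Float32[0.0, ...], γ = Float32[1.0, ...])
Flux.trainable(Dropout(0.5))  # expected: (;), the empty NamedTuple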

src/layers/recurrent.jl

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ function (m::Recur)(x)
 end
 
 @functor Recur
-trainable(a::Recur) = (a.cell,)
+trainable(a::Recur) = (; cell = a.cell)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

src/layers/show.jl

Lines changed: 6 additions & 1 deletion

@@ -14,7 +14,7 @@ for T in [
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
-  children = trainable(obj)
+  children = _show_children(obj)
   if all(_show_leaflike, children)
     _layer_show(io, obj, indent, name)
   else

@@ -48,6 +48,11 @@ _show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
 _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
 _show_leaflike(::Diagonal) = true # appears inside LayerNorm
 
+_show_children(x) = trainable(x) # except for layers which hide their Tuple:
+_show_children(c::Chain) = c.layers
+_show_children(m::Maxout) = m.layers
+_show_children(p::Parallel) = (p.connection, p.layers...)
+
 for T in [
   :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
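Why the pretty-printer needs this indirection (my reading, not stated in the diff): with `@functor Chain` the default `trainable` now nests the sub-layers one level deeper, under the `layers` field, so `_big_show` unwraps them again via `_show_children`. A rough sketch, assuming this commit:

using Flux

c = Chain(Dense(2, 3), relu)
Flux.trainable(c)        # presumably (layers = (Dense(2, 3), relu),) — one extra level of nesting
Flux._show_children(c)   # (Dense(2, 3), relu) — the children the printing code actually walks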

src/utils.jl

Lines changed: 6 additions & 1 deletion

@@ -775,15 +775,20 @@ Chain(
 # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.
 
 julia> Flux.modules(m2)
-5-element Vector{Any}:
+7-element Vector{Any}:
  Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable
+ (Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
  Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable
+ (Dense(784, 64), BatchNorm(64, relu))
  Dense(784, 64) # 50_240 parameters
  BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable
  Dense(64, 10) # 650 parameters
 
 julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
 L2 (generic function with 1 method)
+
+julia> L2(m2) isa Float32
+true
 ```
 """
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
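A note on the count (my inference, not in the diff): the docstring output grows from 5 to 7 entries because `@functor Chain` now exposes the raw `layers` tuple as a child, and `Functors.fcollect` collects those tuples too. Type-based filtering, as in the `L2` example, is unaffected; a hypothetical way to drop the bare tuples:

using Flux

m2 = Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))  # the m2 from the docstring above
filter(x -> !(x isa Tuple), Flux.modules(m2))                          # back to the five layer-like entries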

test/layers/basic.jl

Lines changed: 28 additions & 0 deletions

@@ -25,6 +25,9 @@ import Flux: activations
   @test m[:first] == m[1]
   @test m[1:2] == m
 
+  @test m == m
+  @test m == fmap(identity, m) # does not forget names
+
   @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
 end
 

@@ -202,14 +205,39 @@ import Flux: activations
     inputs = randn(10), randn(5), randn(4)
     @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
+    @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
+    @test Parallel(+, sin, cos)(pi/2) ≈ 1
   end
 
   @testset "named access" begin
     m = Parallel(hcat, one = Dense(10, 10), two = identity)
     @test m[1] == m[:one]
+    @test m[1:2] == m
 
     @test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names
     @test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)
+
+    @test m == fmap(identity, m) # does not forget names
+
+    @test Parallel(vcat, x = log)(1) == [0]
+    @test Parallel(vcat, log)(1) == [0]
+  end
+
+  @testset "trivial cases" begin
+    @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
+    @test Parallel(hcat)(1) == hcat()
+    @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
+  end
+
+  @testset "connection is called once" begin
+    CNT = Ref(0)
+    f_cnt = (x...) -> (CNT[]+=1; +(x...))
+    Parallel(f_cnt, sin, cos, tan)(1)
+    @test CNT[] == 1
+    Parallel(f_cnt, sin, cos, tan)(1,2,3)
+    @test CNT[] == 2
+    Parallel(f_cnt, sin)(1)
+    @test CNT[] == 3
   end
 
   # Ref https://github.com/FluxML/Flux.jl/issues/1673

test/runtests.jl

Lines changed: 44 additions & 41 deletions

@@ -8,55 +8,58 @@ using CUDA
 
 Random.seed!(0)
 
-@testset "Utils" begin
-  include("utils.jl")
-end
+@testset verbose=true "Flux.jl" begin
 
-@testset "Onehot" begin
-  include("onehot.jl")
-end
+  @testset "Utils" begin
+    include("utils.jl")
+  end
 
-@testset "Optimise" begin
-  include("optimise.jl")
-end
+  @testset "Onehot" begin
+    include("onehot.jl")
+  end
 
-@testset "Data" begin
-  include("data.jl")
-end
+  @testset "Optimise" begin
+    include("optimise.jl")
+  end
 
-@testset "Losses" begin
-  include("losses.jl")
-  include("ctc.jl")
-  CUDA.functional() && include("ctc-gpu.jl")
-end
+  @testset "Data" begin
+    include("data.jl")
+  end
 
-@testset "Layers" begin
-  include("layers/basic.jl")
-  include("layers/normalisation.jl")
-  include("layers/stateless.jl")
-  include("layers/recurrent.jl")
-  include("layers/conv.jl")
-  include("layers/upsample.jl")
-  include("layers/show.jl")
-end
+  @testset "Losses" begin
+    include("losses.jl")
+    include("ctc.jl")
+    CUDA.functional() && include("ctc-gpu.jl")
+  end
 
-@testset "outputsize" begin
-  using Flux: outputsize
-  include("outputsize.jl")
-end
+  @testset "Layers" begin
+    include("layers/basic.jl")
+    include("layers/normalisation.jl")
+    include("layers/stateless.jl")
+    include("layers/recurrent.jl")
+    include("layers/conv.jl")
+    include("layers/upsample.jl")
+    include("layers/show.jl")
+  end
 
-@testset "CUDA" begin
-  if CUDA.functional()
-    include("cuda/runtests.jl")
-  else
-    @warn "CUDA unavailable, not testing GPU support"
+  @testset "outputsize" begin
+    using Flux: outputsize
+    include("outputsize.jl")
+  end
+
+  @testset "CUDA" begin
+    if CUDA.functional()
+      include("cuda/runtests.jl")
+    else
+      @warn "CUDA unavailable, not testing GPU support"
+    end
   end
-end
 
-@static if VERSION == v"1.6"
-  using Documenter
-  @testset "Docs" begin
-    DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
-    doctest(Flux)
+  @static if VERSION == v"1.6"
+    using Documenter
+    @testset "Docs" begin
+      DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
+      doctest(Flux)
+    end
   end
 end
