Commit 19a9205

Testing new apply interface for Flux.Chain (#5)

Authored by mkschleg and ToucheSir
Co-authored-by: Brian Chen <[email protected]>
1 parent d917e17 · commit 19a9205

File tree: 4 files changed, +157 −0 lines

  src/Fluxperimental.jl
  src/chain.jl
  test/chain.jl
  test/runtests.jl

src/Fluxperimental.jl

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,9 @@ export Split, Join
 include("train.jl")
 export shinkansen!

+
+include("chain.jl")
+
 include("compact.jl")

 end # module Fluxperimental

src/chain.jl

Lines changed: 62 additions & 0 deletions (new file)

import Flux: ChainRulesCore
# Some experiments with chain to start removing the need for recur to be mutable.
# As per the conversation in the recurrent network rework issue.

# Main difference between this and the _applychain function is we return a new chain
# with the internal state modified as well as the output of applying x to the chain.
function apply(chain::Flux.Chain, x)
  layers, out = _apply(chain.layers, x)
  Flux.Chain(layers), out
end

function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
  layers, out = _apply(Tuple(layers), x)
  NamedTuple{NMS}(layers), out
end

function _scan(layers::AbstractVector, x)
  new_layers = typeof(layers)(undef, length(layers))
  for (idx, f) in enumerate(layers)
    new_layers[idx], x = _apply(f, x)
  end
  new_layers, x
end

# Reverse rule for _scan
# example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
  duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer
    out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
  end
  outs = map(first, duo)
  backs = map(last, duo)

  function _scan_pullback(dy)
    multi = accumulate(reverse(backs); init=(nothing, dy)) do (_, delta), back
      dapply, dlayer, din = back(delta)
      return dapply, (dlayer, din)
    end
    layergrads = reverse(map(first, multi))
    xgrad = last(multi[end])
    return (ChainRulesCore.NoTangent(), layergrads, xgrad)
  end
  return (map(first, outs), last(outs[end])), _scan_pullback
end

function _apply(layers::AbstractVector, x)  # type-unstable path, helps compile times
  _scan(layers, x)
end

# Generated function returns a tuple of args and the last output of the network.
@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
  x_symbols = vcat(:x, [gensym() for _ in 1:N])
  l_symbols = [gensym() for _ in 1:N]
  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
  push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
  Expr(:block, calls...)
end

_apply(layer, x) = layer, layer(x)
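
For orientation, here is a minimal usage sketch of the new interface (layer sizes and variable names mirror the tests below and are illustrative only): a plain chain call returns just the output, while apply returns a rebuilt chain together with the output.

using Flux, Fluxperimental

l1 = Flux.Dense(3, 4)
l2 = Flux.Dense(4, 1)
c  = Flux.Chain(l1, l2)
x  = rand(Float32, 3, 1)

y = c(x)                                 # plain call: output only
new_c, out = Fluxperimental.apply(c, x)  # new interface: (rebuilt chain, output)

all(out .== y)    # true: same forward computation
new_c[1] === l1   # true here, since Dense has no internal state to update

For stateless layers the rebuilt chain holds the same layer objects; the point of returning it is to thread updated state through stateful layers such as recurrent cells without mutating them.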
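
The @generated method unrolls the tuple recursion at compile time. For a two-layer tuple chain it produces code roughly equivalent to the hand-written sketch below (the name _apply_unrolled and the variable names are illustrative; the real method uses gensym'd symbols):

# Illustrative expansion only, not part of the committed code.
function _apply_unrolled(layers::Tuple{Any, Any}, x)
  l1, x1 = _apply(layers[1], x)
  l2, x2 = _apply(layers[2], x1)
  return tuple(l1, l2), x2
end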

test/chain.jl

Lines changed: 88 additions & 0 deletions (new file)

# Checking if the two grad structures are equal. Simplifies tests below.
function _grads_equal(grads1, grads2)
  if length(keys(grads1)) != length(keys(grads2))
    return false
  end
  ret = true
  for weights in keys(grads1)
    if grads1[weights] isa AbstractArray
      ret = ret && all(grads1[weights] .== grads2[weights])
    elseif isnothing(grads1[weights])
      ret = ret && isnothing(grads2[weights])
    else
      throw("Grad returned type $(typeof(grads1[weights]))")
    end
  end
  return ret
end

@testset "Applying the Chain!" begin
  @testset "Forward pass" begin
    x = rand(Float32, 3, 1)
    l1 = Flux.Dense(3, 4)
    l2 = Flux.Dense(4, 1)
    truth = l2(l1(x))

    t_c = Flux.Chain(l1, l2) # tuple Chain
    new_t_c, out = Fluxperimental.apply(t_c, x)
    @test new_t_c[1] === l1 && new_t_c[2] === l2
    @test all(out .== truth)

    nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain
    new_nt_c, out = Fluxperimental.apply(nt_c, x)
    @test new_nt_c[:l1] === l1 && new_nt_c[:l2] === l2
    @test all(out .== truth)

    v_c = Flux.Chain([l1, l2]) # vector Chain
    new_v_c, out = Fluxperimental.apply(v_c, x)
    @test new_v_c.layers[1] === l1 && new_v_c.layers[2] === l2
    @test all(out .== truth)
  end # @testset "Forward pass"

  @testset "Backward pass" begin
    x = rand(Float32, 3, 1)
    l1 = Flux.Dense(3, 4)
    l2 = Flux.Dense(4, 1)

    @test begin # Test tuple Chain gradients
      t_c = Flux.Chain(l1, l2) # tuple Chain
      grads_truth = Flux.gradient(Flux.params(t_c)) do
        sum(t_c(x))
      end

      grads_tuple = Flux.gradient(Flux.params(t_c)) do
        sum(Fluxperimental.apply(t_c, x)[end])
      end

      _grads_equal(grads_tuple, grads_truth)
    end

    @test begin # Test namedtuple Chain gradients
      nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain
      grads_truth = Flux.gradient(Flux.params(nt_c)) do
        sum(nt_c(x))
      end

      grads_tuple = Flux.gradient(Flux.params(nt_c)) do
        sum(Fluxperimental.apply(nt_c, x)[end])
      end

      _grads_equal(grads_tuple, grads_truth)
    end

    @test begin # Test vector Chain gradients
      c = Flux.Chain([l1, l2]) # vector Chain
      grads_truth = Flux.gradient(Flux.params(c)) do
        sum(c(x))
      end

      grads_tuple = Flux.gradient(Flux.params(c)) do
        sum(Fluxperimental.apply(c, x)[end])
      end

      _grads_equal(grads_tuple, grads_truth)
    end
  end # @testset "Backward pass"
end # @testset "Applying the Chain!"

test/runtests.jl

Lines changed: 4 additions & 0 deletions

@@ -3,5 +3,9 @@ using Flux, Fluxperimental

 @testset "Fluxperimental.jl" begin
   include("split_join.jl")
+
+  include("chain.jl")
+
   include("compact.jl")
+
 end
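
With test/chain.jl wired into the suite, the new chain tests run alongside the existing ones. A minimal sketch of running them locally (the paths assume a checkout of the Fluxperimental repository):

using Pkg
Pkg.activate(".")   # from the repository root; path is illustrative
Pkg.test()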
