|
| 1 | + |
| 2 | +import Flux: ChainRulesCore |
# Experiments with `Chain` aimed at removing the need for `Recur` to be mutable,
# following the discussion in the recurrent-network rework issue.
| 5 | + |
# Counterpart of `_applychain` that also hands back the chain: `x` is pushed through
# `chain.layers` via `_apply`, the layers it returns are wrapped in a fresh `Flux.Chain`
# (carrying whatever updated internal state `_apply` produced), and that new chain is
# returned together with the final output as `(new_chain, output)`.
function apply(chain::Flux.Chain, x)
    updated_layers, y = _apply(chain.layers, x)
    return Flux.Chain(updated_layers), y
end
| 12 | + |
# Named-tuple container: run the layers positionally through the tuple method of
# `_apply`, then re-attach the original field names `NMS` to the returned layers.
# Returns `(new_layers::NamedTuple, output)`.
function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
    positional = Tuple(layers)
    applied, y = _apply(positional, x)
    renamed = NamedTuple{NMS}(applied)
    return renamed, y
end
| 17 | + |
# Apply every layer of `layers` in order, feeding each layer's output to the next one.
#
# Returns `(new_layers, out)`, where `new_layers[i]` is the layer handed back by
# `_apply(layers[i], ...)` and `out` is the output of the last layer (or `x` itself
# when `layers` is empty).
#
# The result container comes from `similar(layers)` and is filled over
# `eachindex(layers)`, so any `AbstractVector` works (views, offset axes, ...), not
# only types that provide a `T(undef, n)` constructor with 1-based indices.
function _scan(layers::AbstractVector, x)
    new_layers = similar(layers)
    for idx in eachindex(layers)
        new_layers[idx], x = _apply(layers[idx], x)
    end
    return new_layers, x
end
| 25 | + |
# Reverse rule for `_scan`.
# Pattern adapted from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
#
# Forward: every layer is run through `rrule_via_ad(cfg, _apply, layer, input)`, which
# yields that layer's `(new_layer, output)` pair and a pullback. The primal result
# matches `_scan`: `(new_layers, final_output)`.
#
# Pullback: `dy` is the cotangent of `(new_layers, final_output)`. Walking the layers
# in reverse, layer `i` receives the cotangent pair
# `(cotangent of new_layers[i], cotangent of its output)` and returns
# `(d_apply, dlayer, dinput)`; `dlayer` is collected as the gradient for `layers[i]`
# and `dinput` becomes the output cotangent of layer `i - 1`.
# Returns `(NoTangent(), layergrads, xgrad)`.
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
    # Each accumulated element is `((new_layer, output), pullback)`; only the `output`
    # of the previous element is forwarded to the next layer.
    duo = accumulate(layers; init=((nothing, x), nothing)) do ((_, input), _), cur_layer
        ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
    end
    outs = map(first, duo)
    backs = map(last, duo)
    # With no layers `_scan` passes `x` straight through.
    y = isempty(outs) ? x : last(outs[end])

    function _scan_pullback(dy)
        dys = ChainRulesCore.unthunk(dy)
        # A zero cotangent cannot be destructured, so it stands in for both parts.
        dstates_raw, dout = dys isa ChainRulesCore.AbstractZero ? (dys, dys) : dys
        # NOTE(review): assumes a non-zero cotangent for `new_layers` is indexable
        # with the same indices as `layers` -- confirm against the AD backend in use.
        dstates = ChainRulesCore.unthunk(dstates_raw)

        multi = accumulate(reverse(eachindex(backs)); init=(nothing, dout)) do (_, delta), idx
            dstate = dstates isa ChainRulesCore.AbstractZero ? dstates : dstates[idx]
            dpair = ChainRulesCore.Tangent{typeof(outs[idx])}(dstate, delta)
            # The first entry is the tangent of the `_apply` function itself; drop it.
            _, dlayer, din = backs[idx](dpair)
            return dlayer, din
        end
        layergrads = reverse(map(first, multi))
        # The last accumulated element belongs to the first layer; its input is `x`.
        xgrad = isempty(multi) ? dout : last(multi[end])
        return (ChainRulesCore.NoTangent(), layergrads, xgrad)
    end
    return (map(first, outs), y), _scan_pullback
end
| 46 | + |
# Vector-of-layers container: delegate to the loop-based `_scan`.
# This is the type-unstable path; it helps compile times.
_apply(layers::AbstractVector, x) = _scan(layers, x)
| 50 | + |
# Tuple container. Generates an unrolled sequence of `_apply` calls, one per layer:
#
#     (l1, x1) = _apply(layers[1], x)
#     (l2, x2) = _apply(layers[2], x1)
#     ...
#     return tuple(l1, l2, ...), xN
#
# and so returns the tuple of layers handed back by `_apply` together with the output
# of the last layer (`((), x)` for an empty tuple).
#
# The signature uses `Vararg{Any,N}`: it matches the same tuples as `Vararg{<:Any,N}`
# but does not wrap `Vararg` directly in a `UnionAll`, which newer Julia deprecates.
@generated function _apply(layers::Tuple{Vararg{Any,N}}, x) where {N}
    # Typed comprehensions keep the element types fixed even when `N == 0`.
    x_symbols = vcat(:x, Symbol[gensym() for _ in 1:N])
    l_symbols = Symbol[gensym() for _ in 1:N]
    calls = Expr[
        :(($(l_symbols[i]), $(x_symbols[i + 1])) = _apply(layers[$i], $(x_symbols[i])))
        for i in 1:N
    ]
    push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
    return Expr(:block, calls...)
end
| 59 | + |
# Fallback for a single layer: call it on `x` and return the layer itself, unchanged,
# next to its output as `(layer, output)`.
function _apply(layer, x)
    y = layer(x)
    return layer, y
end
| 61 | + |
| 62 | + |
0 commit comments