Skip to content

Commit 98e7222

Browse files
bors[bot]Dhairya Gandhi
and
Dhairya Gandhi
authored
Merge #1358
1358: Fix BPTT by overriding stateful broadcast adjoint r=DhairyaLGandhi a=DhairyaLGandhi Fixes #1209 In this PR, we replace the regular broadcasting adjoint with that of the `map` equivalent which is better tested in terms of stateful cases. We ultimately will revert back to the broadacasting adjoint via FluxML/Zygote.jl#807 but this specialises the case for recurrent layers @oxinabox @ToucheSir Comments? ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [x] Documentation, if applicable - [ ] Final review from `@dhairyagandhi96` (for API changes). Co-authored-by: Dhairya Gandhi <[email protected]>
2 parents 9ed04bb + f6f9925 commit 98e7222

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

src/layers/recurrent.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,7 @@ See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
184184
for a good overview of the internals.
185185
"""
186186
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
187+
188+
@adjoint function Broadcast.broadcasted(f::Recur, args...)
189+
Zygote.∇map(__context__, f, args...)
190+
end

test/layers/recurrent.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Ref FluxML/Flux.jl#1209
2+
@testset "BPTT" begin
3+
seq = [rand(2) for i = 1:3]
4+
for r [RNN,]
5+
rnn = r(2,3)
6+
Flux.reset!(rnn)
7+
grads_seq = gradient(Flux.params(rnn)) do
8+
sum(rnn.(seq)[3])
9+
end
10+
Flux.reset!(rnn);
11+
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
12+
tanh.(rnn.cell.Wi * seq[2] + Wh *
13+
tanh.(rnn.cell.Wi * seq[1] +
14+
Wh * rnn.init
15+
+ rnn.cell.b)
16+
+ rnn.cell.b)
17+
+ rnn.cell.b)),
18+
rnn.cell.Wh)
19+
@test grads_seq[rnn.cell.Wh] bptt[1]
20+
end
21+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ end
3030
include("layers/basic.jl")
3131
include("layers/normalisation.jl")
3232
include("layers/stateless.jl")
33+
include("layers/recurrent.jl")
3334
include("layers/conv.jl")
3435
end
3536

0 commit comments

Comments
 (0)