Skip to content

Commit 8b53033

Browse files
bors[bot] and jeremiedb authored
Merge #1390
1390: RNN deprecations and naming fixes r=CarloLucibello a=jeremiedb Following the discussion in #1367, this PR disambiguates the initial state parameters, now named `state0` in the RNN cells, from the state of the RNN chain, named `state` in the `Recur` struct. It adds `getproperty` methods with deprecation messages for accessing the legacy `h` and `c` fields of the RNN cells, as well as the `init` field of `Recur` (which now points to `recur.cell.state0`). The basic BPTT test now covers both 1D and 2D input dimensions. Co-authored-by: jeremie.db <[email protected]>
2 parents 09764f8 + 70e4797 commit 8b53033

File tree

2 files changed

+77
-13
lines changed

2 files changed

+77
-13
lines changed

src/layers/recurrent.jl

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ mutable struct Recur{T,S}
3131
end
3232

3333
function (m::Recur)(xs...)
34-
h, y = m.cell(m.state, xs...)
35-
m.state = h
34+
m.state, y = m.cell(m.state, xs...)
3635
return y
3736
end
3837

@@ -51,9 +50,19 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
5150
rnn.state = hidden(rnn.cell)
5251
```
5352
"""
54-
reset!(m::Recur) = (m.state = m.cell.state)
53+
reset!(m::Recur) = (m.state = m.cell.state0)
5554
reset!(m) = foreach(reset!, functor(m)[1])
5655

56+
# TODO remove in v0.13
57+
function Base.getproperty(m::Recur, sym::Symbol)
58+
if sym === :init
59+
@warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
60+
return getfield(m.cell, :state0)
61+
else
62+
return getfield(m, sym)
63+
end
64+
end
65+
5766
flip(f, xs) = reverse(f.(reverse(xs)))
5867

5968
# Vanilla RNN
@@ -63,7 +72,7 @@ struct RNNCell{F,A,V,S}
6372
Wi::A
6473
Wh::A
6574
b::V
66-
state::S
75+
state0::S
6776
end
6877

6978
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
@@ -89,16 +98,26 @@ end
8998
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
9099
output fed back into the input each time step.
91100
"""
92-
Recur(m::RNNCell) = Recur(m, m.state)
101+
Recur(m::RNNCell) = Recur(m, m.state0)
93102
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
94103

104+
# TODO remove in v0.13
105+
function Base.getproperty(m::RNNCell, sym::Symbol)
106+
if sym === :h
107+
@warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
108+
return getfield(m, :state0)
109+
else
110+
return getfield(m, sym)
111+
end
112+
end
113+
95114
# LSTM
96115

97116
struct LSTMCell{A,V,S}
98117
Wi::A
99118
Wh::A
100119
b::V
101-
state::S
120+
state0::S
102121
end
103122

104123
function LSTMCell(in::Integer, out::Integer;
@@ -138,16 +157,29 @@ for a good overview of the internals.
138157
"""
139158
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
140159
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
141-
Recur(m::LSTMCell) = Recur(m, m.state)
160+
Recur(m::LSTMCell) = Recur(m, m.state0)
142161
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
143162

163+
# TODO remove in v0.13
164+
function Base.getproperty(m::LSTMCell, sym::Symbol)
165+
if sym === :h
166+
@warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
167+
return getfield(m, :state0)[1]
168+
elseif sym === :c
169+
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
170+
return getfield(m, :state0)[2]
171+
else
172+
return getfield(m, sym)
173+
end
174+
end
175+
144176
# GRU
145177

146178
struct GRUCell{A,V,S}
147179
Wi::A
148180
Wh::A
149181
b::V
150-
state::S
182+
state0::S
151183
end
152184

153185
GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
@@ -159,7 +191,7 @@ function (m::GRUCell)(h, x)
159191
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
160192
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
161193
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
162-
h′ = (1 .- z).*h̃ .+ z.*h
194+
h′ = (1 .- z) .* h̃ .+ z .* h
163195
return h′, h′
164196
end
165197

@@ -177,9 +209,19 @@ RNN but generally exhibits a longer memory span over sequences.
177209
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
178210
for a good overview of the internals.
179211
"""
180-
Recur(m::GRUCell) = Recur(m, m.state)
212+
Recur(m::GRUCell) = Recur(m, m.state0)
181213
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
182214

215+
# TODO remove in v0.13
216+
function Base.getproperty(m::GRUCell, sym::Symbol)
217+
if sym === :h
218+
@warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
219+
return getfield(m, :state0)
220+
else
221+
return getfield(m, sym)
222+
end
223+
end
224+
183225
@adjoint function Broadcast.broadcasted(f::Recur, args...)
184226
Zygote.∇map(__context__, f, args...)
185227
end

test/layers/recurrent.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# Ref FluxML/Flux.jl#1209
1+
# Ref FluxML/Flux.jl#1209 1D input
22
@testset "BPTT" begin
3-
seq = [rand(Float32, (2,1)) for i = 1:3]
3+
seq = [rand(Float32, 2) for i = 1:3]
44
for r ∈ [RNN,]
55
rnn = r(2,3)
66
Flux.reset!(rnn)
@@ -11,11 +11,33 @@
1111
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
1212
tanh.(rnn.cell.Wi * seq[2] + Wh *
1313
tanh.(rnn.cell.Wi * seq[1] +
14-
Wh * rnn.cell.state
14+
Wh * rnn.cell.state0
1515
+ rnn.cell.b)
1616
+ rnn.cell.b)
1717
+ rnn.cell.b)),
1818
rnn.cell.Wh)
1919
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
2020
end
2121
end
22+
23+
# Ref FluxML/Flux.jl#1209 2D input
24+
@testset "BPTT" begin
25+
seq = [rand(Float32, (2,1)) for i = 1:3]
26+
for r ∈ [RNN,]
27+
rnn = r(2,3)
28+
Flux.reset!(rnn)
29+
grads_seq = gradient(Flux.params(rnn)) do
30+
sum(rnn.(seq)[3])
31+
end
32+
Flux.reset!(rnn);
33+
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
34+
tanh.(rnn.cell.Wi * seq[2] + Wh *
35+
tanh.(rnn.cell.Wi * seq[1] +
36+
Wh * rnn.cell.state0
37+
+ rnn.cell.b)
38+
+ rnn.cell.b)
39+
+ rnn.cell.b)),
40+
rnn.cell.Wh)
41+
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
42+
end
43+
end

0 commit comments

Comments
 (0)