@@ -31,8 +31,7 @@ mutable struct Recur{T,S}
31
31
end
32
32
33
33
function (m:: Recur )(xs... )
34
- h, y = m. cell (m. state, xs... )
35
- m. state = h
34
+ m. state, y = m. cell (m. state, xs... )
36
35
return y
37
36
end
38
37
@@ -51,9 +50,19 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
51
50
rnn.state = hidden(rnn.cell)
52
51
```
53
52
"""
54
- reset! (m:: Recur ) = (m. state = m. cell. state )
53
+ reset! (m:: Recur ) = (m. state = m. cell. state0 )
55
54
reset! (m) = foreach (reset!, functor (m)[1 ])
56
55
56
+ # TODO remove in v0.13
57
+ function Base. getproperty (m:: Recur , sym:: Symbol )
58
+ if sym === :init
59
+ @warn " Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
60
+ return getfield (m. cell, :state0 )
61
+ else
62
+ return getfield (m, sym)
63
+ end
64
+ end
65
+
57
66
flip (f, xs) = reverse (f .(reverse (xs)))
58
67
59
68
# Vanilla RNN
@@ -63,7 +72,7 @@ struct RNNCell{F,A,V,S}
63
72
Wi:: A
64
73
Wh:: A
65
74
b:: V
66
- state :: S
75
+ state0 :: S
67
76
end
68
77
69
78
RNNCell (in:: Integer , out:: Integer , σ= tanh; init= Flux. glorot_uniform, initb= zeros, init_state= zeros) =
89
98
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
90
99
output fed back into the input each time step.
91
100
"""
92
- Recur (m:: RNNCell ) = Recur (m, m. state )
101
+ Recur (m:: RNNCell ) = Recur (m, m. state0 )
93
102
RNN (a... ; ka... ) = Recur (RNNCell (a... ; ka... ))
94
103
104
+ # TODO remove in v0.13
105
+ function Base. getproperty (m:: RNNCell , sym:: Symbol )
106
+ if sym === :h
107
+ @warn " RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
108
+ return getfield (m, :state0 )
109
+ else
110
+ return getfield (m, sym)
111
+ end
112
+ end
113
+
95
114
# LSTM
96
115
97
116
struct LSTMCell{A,V,S}
98
117
Wi:: A
99
118
Wh:: A
100
119
b:: V
101
- state :: S
120
+ state0 :: S
102
121
end
103
122
104
123
function LSTMCell (in:: Integer , out:: Integer ;
@@ -138,16 +157,29 @@ for a good overview of the internals.
138
157
"""
139
158
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
140
159
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
141
- Recur (m:: LSTMCell ) = Recur (m, m. state )
160
+ Recur (m:: LSTMCell ) = Recur (m, m. state0 )
142
161
LSTM (a... ; ka... ) = Recur (LSTMCell (a... ; ka... ))
143
162
163
+ # TODO remove in v0.13
164
+ function Base. getproperty (m:: LSTMCell , sym:: Symbol )
165
+ if sym === :h
166
+ @warn " LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
167
+ return getfield (m, :state0 )[1 ]
168
+ elseif sym === :c
169
+ @warn " LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
170
+ return getfield (m, :state0 )[2 ]
171
+ else
172
+ return getfield (m, sym)
173
+ end
174
+ end
175
+
144
176
# GRU
145
177
146
178
struct GRUCell{A,V,S}
147
179
Wi:: A
148
180
Wh:: A
149
181
b:: V
150
- state :: S
182
+ state0 :: S
151
183
end
152
184
153
185
GRUCell (in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
@@ -159,7 +191,7 @@ function (m::GRUCell)(h, x)
159
191
r = σ .(gate (gx, o, 1 ) .+ gate (gh, o, 1 ) .+ gate (b, o, 1 ))
160
192
z = σ .(gate (gx, o, 2 ) .+ gate (gh, o, 2 ) .+ gate (b, o, 2 ))
161
193
h̃ = tanh .(gate (gx, o, 3 ) .+ r .* gate (gh, o, 3 ) .+ gate (b, o, 3 ))
162
- h′ = (1 .- z). * h̃ .+ z.* h
194
+ h′ = (1 .- z) .* h̃ .+ z .* h
163
195
return h′, h′
164
196
end
165
197
@@ -177,9 +209,19 @@ RNN but generally exhibits a longer memory span over sequences.
177
209
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
178
210
for a good overview of the internals.
179
211
"""
180
- Recur (m:: GRUCell ) = Recur (m, m. state )
212
+ Recur (m:: GRUCell ) = Recur (m, m. state0 )
181
213
GRU (a... ; ka... ) = Recur (GRUCell (a... ; ka... ))
182
214
215
+ # TODO remove in v0.13
216
+ function Base. getproperty (m:: GRUCell , sym:: Symbol )
217
+ if sym === :h
218
+ @warn " GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
219
+ return getfield (m, :state0 )
220
+ else
221
+ return getfield (m, sym)
222
+ end
223
+ end
224
+
183
225
@adjoint function Broadcast. broadcasted (f:: Recur , args... )
184
226
Zygote.∇map (__context__, f, args... )
185
227
end
0 commit comments