
Commit a015e5a

RNNs redesign (#2500)
1 parent 0567cb3 commit a015e5a

19 files changed, +1069 -849 lines

NEWS.md

Lines changed: 8 additions & 0 deletions
```diff
@@ -2,6 +2,14 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
+## v0.15.0
+* Recurrent layers have undergone a complete redesign in [PR 2500](https://github.com/FluxML/Flux.jl/pull/2500).
+  * `RNN`, `LSTM`, and `GRU` no longer store the hidden state internally. Instead, they now take the previous state as input and return the updated state as output.
+  * These layers (`RNN`, `LSTM`, `GRU`) now process entire sequences at once, rather than one element at a time.
+  * The `Recur` wrapper has been deprecated and removed.
+  * The `reset!` function has also been removed; state management is now entirely up to the user.
+  * `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing.
+
 ## v0.14.22
 * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
```

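Taken together, these notes describe the following usage pattern. This is a minimal sketch assembled from the documentation changes further down in this commit; sizes and variable names are illustrative, not part of the diff:

```julia
using Flux

input_size, hidden_size, seq_len, batch_size = 2, 5, 3, 4

# A cell handles a single time step: the previous state goes in,
# the updated state comes out, and the caller keeps track of it.
cell = RNNCell(input_size => hidden_size)
h0 = zeros(Float32, hidden_size)
x1 = rand(Float32, input_size, batch_size)
h1 = cell(x1, h0)

# The RNN layer processes a whole sequence at once, given the initial state.
rnn = RNN(input_size => hidden_size)
x = rand(Float32, input_size, seq_len, batch_size)
h = rnn(x, h0)   # one hidden state per time step: (hidden_size, seq_len, batch_size)
```
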
docs/src/guide/models/recurrence.md

Lines changed: 103 additions & 166 deletions
````diff
@@ -6,7 +6,7 @@ To introduce Flux's recurrence functionalities, we will consider the following v
 
 ![](../../assets/rnn-basic.png)
 
-In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs.
+In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step. It could be a timestamp or a word in a sentence encoded as vectors. `y1` to `y3` are their respective outputs.
 
 An aspect to recognise is that in such a model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
 
@@ -17,215 +17,152 @@ output_size = 5
 input_size = 2
 Wxh = randn(Float32, output_size, input_size)
 Whh = randn(Float32, output_size, output_size)
-b = randn(Float32, output_size)
+b = zeros(Float32, output_size)
 
-function rnn_cell(h, x)
+function rnn_cell(x, h)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h, h
+    return h
 end
 
-x = rand(Float32, input_size) # dummy input data
-h = rand(Float32, output_size) # random initial hidden state
-
-h, y = rnn_cell(h, x)
+seq_len = 3
+# dummy input data
+x = [rand(Float32, input_size) for i = 1:seq_len]
+# random initial hidden state
+h0 = zeros(Float32, output_size)
+
+y = []
+ht = h0
+for xt in x
+    ht = rnn_cell(xt, ht)
+    y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
+end
 ```
 
-Notice how the above is essentially a `Dense` layer that acts on two inputs, `h` and `x`.
+Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
 
-If you run the last line a few times, you'll notice the output `y` changing slightly even though the input `x` is the same.
+The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
 
 There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:
 
 ```julia
 using Flux
 
-rnn = Flux.RNNCell(2, 5)
+output_size = 5
+input_size = 2
+seq_len = 3
+x = [rand(Float32, input_size) for i = 1:seq_len]
+h0 = zeros(Float32, output_size)
 
-x = rand(Float32, 2) # dummy data
-h = rand(Float32, 5) # initial hidden state
+rnn_cell = Flux.RNNCell(input_size => output_size)
 
-h, y = rnn(h, x)
+y = []
+ht = h0
+for xt in x
+    ht = rnn_cell(xt, ht)
+    y = [y; [ht]]
+end
 ```
+The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
 
-## Stateful Models
+## Using a cell as part of a model
 
-For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.
+Let's consider a simple model that is trained to predict a scalar quantity for each time step in a sequence. The model will have a single RNN cell, followed by a dense layer to produce the output.
+Since the [`RNNCell`](@ref) can deal with batches of data, we can define the model to accept an input where
+at each time step, the input is a matrix of size `(input_size, batch_size)`.
 
 ```julia
-x = rand(Float32, 2)
-h = rand(Float32, 5)
-
-m = Flux.Recur(rnn, h)
-
-y = m(x)
-```
-
-The `Recur` wrapper stores the state between runs in the `m.state` field.
-
-If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
-
-```jldoctest recurrence
-julia> using Flux
-
-julia> RNN(2, 5)  # or equivalently RNN(2 => 5)
-Recur(
-  RNNCell(2 => 5, tanh),                # 45 parameters
-)         # Total: 4 trainable arrays, 45 parameters,
-          # plus 1 non-trainable, 5 parameters, summarysize 404 bytes.
-```
-
-Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
-
-Using these tools, we can now build the model shown in the above diagram with:
-
-```jldoctest recurrence
-julia> m = Chain(RNN(2 => 5), Dense(5 => 1))
-Chain(
-  Recur(
-    RNNCell(2 => 5, tanh),              # 45 parameters
-  ),
-  Dense(5 => 1),                        # 6 parameters
-)                 # Total: 6 trainable arrays, 51 parameters,
-                  # plus 1 non-trainable, 5 parameters, summarysize 540 bytes.
-```
-In this example, each output has only one component.
-
-## Working with sequences
-
-Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:
-
-```jldoctest recurrence; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> x = rand(Float32, 2);
-
-julia> m(x)
-1-element Vector{Float32}:
- 0.45860028
-```
-
-The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
-If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2`
-since the model `m` has stored the state resulting from the `x1` step.
-
-Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by
-iterating the model on a sequence of data.
-
-To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
-
-```jldoctest recurrence; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> x = [rand(Float32, 2) for i = 1:3];
-
-julia> [m(xi) for xi in x]
-3-element Vector{Vector{Float32}}:
- [0.36080405]
- [-0.13914406]
- [0.9310162]
-```
-
-!!! warning "Use of map and broadcast"
-    Mapping and broadcasting operations with stateful layers such are discouraged,
-    since the julia language doesn't guarantee a specific execution order.
-    Therefore, avoid
-    ```julia
-    y = m.(x)
-    # or
-    y = map(m, x)
-    ```
-    and use explicit loops
-    ```julia
-    y = [m(x) for x in x]
-    ```
-
-If for some reason one wants to exclude the first step of the RNN chain for the computation of the loss, that can be handled with:
+struct RecurrentCellModel{H,C,D}
+    h0::H
+    cell::C
+    dense::D
+end
 
-```julia
-using Flux.Losses: mse
+# we choose to not train the initial hidden state
+Flux.@layer RecurrentCellModel trainable=(cell,dense)
 
-function loss(x, y)
-    m(x[1]) # ignores the output but updates the hidden states
-    sum(mse(m(xi), yi) for (xi, yi) in zip(x[2:end], y))
+function RecurrentCellModel(input_size::Int, hidden_size::Int)
+    return RecurrentCellModel(
+             zeros(Float32, hidden_size),
+             RNNCell(input_size => hidden_size),
+             Dense(hidden_size => 1))
 end
 
-y = [rand(Float32, 1) for i=1:2]
-loss(x, y)
+function (m::RecurrentCellModel)(x)
+    z = []
+    ht = m.h0
+    for xt in x
+        ht = m.cell(xt, ht)
+        z = [z; [ht]]
+    end
+    z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
+    ŷ = m.dense(z)       # [1, seq_len, batch_size] or [1, seq_len]
+    return ŷ
+end
 ```
 
-In such a model, only the last two outputs are used to compute the loss, hence the target `y` being of length 2. This is a strategy that can be used to easily handle a `seq-to-one` kind of structure, compared to the `seq-to-seq` assumed so far.
+Notice that we stack the hidden states `z` to form a tensor of size `(hidden_size, seq_len, batch_size)`. This can speedup the final classification, since we then process all the outputs at once with a single forward pass of the dense layer.
 
-Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update:
+Let's now define the training loop for this model:
 
 ```julia
-function loss(m, x, y)
-    sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
-end
-
-seq_init = [rand(Float32, 2)]
-seq_1 = [rand(Float32, 2) for i = 1:3]
-seq_2 = [rand(Float32, 2) for i = 1:3]
+using Optimisers: AdamW
 
-y1 = [rand(Float32, 1) for i = 1:3]
-y2 = [rand(Float32, 1) for i = 1:3]
+function loss(model, x, y)
+    ŷ = model(x)
+    y = stack(y, dims=2)
+    return Flux.mse(ŷ, y)
+end
 
-X = [seq_1, seq_2]
-Y = [y1, y2]
-data = zip(X,Y)
+# create dummy data
+seq_len, batch_size, input_size = 3, 4, 2
+x = [rand(Float32, input_size, batch_size) for _ = 1:seq_len]
+y = [rand(Float32, 1, batch_size) for _ = 1:seq_len]
 
-Flux.reset!(m)
-[m(x) for x in seq_init]
+# initialize the model and optimizer
+model = RecurrentCellModel(input_size, 5)
+opt_state = Flux.setup(AdamW(1e-3), model)
 
-opt = Flux.setup(Adam(1e-3), m)
-Flux.train!(loss, m, data, opt)
+# compute the gradient and update the model
+g = gradient(m -> loss(m, x, y), model)[1]
+Flux.update!(opt_state, model, g)
 ```
 
-In this previous example, model's state is first reset with `Flux.reset!`. Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss.
+## Handling the whole sequence at once
 
-In this scenario, it is important to note that a single continuous sequence is considered. Since the model state is not reset between the 2 batches, the state of the model flows through the batches, which only makes sense in the context where `seq_1` is the continuation of `seq_init` and so on.
+In the above example, we processed the sequence one time step at a time using a recurrent cell. However, it is possible to process the entire sequence at once. This can be done by stacking the input data `x` to form a tensor of size `(input_size, seq_len)` or `(input_size, seq_len, batch_size)`.
+One can then use the [`RNN`](@ref), [`LSTM`](@ref) or [`GRU`](@ref) layers to process the entire input tensor.
 
-Batch size would be 1 here as there's only a single sequence within each batch. If the model was to be trained on multiple independent sequences, then these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such scenario, if we set the batch size to 4, a single batch would be of the shape:
+Let's consider the same example as above, but this time we use an `RNN` layer instead of an `RNNCell`:
 
 ```julia
-x = [rand(Float32, 2, 4) for i = 1:3]
-y = [rand(Float32, 1, 4) for i = 1:3]
-```
-
-That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])`, would still represent `x1 -> y1` in our diagram and returns the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix). We do not need to use `Flux.reset!(m)` here; each sentence in the batch will output in its own "column", and the outputs of the different sentences won't mix.
-
-To illustrate, we go through an example of batching with our implementation of `rnn_cell`. The implementation doesn't need to change; the batching comes for "free" from the way Julia does broadcasting and the rules of matrix multiplication.
-
-```julia
-output_size = 5
-input_size = 2
-Wxh = randn(Float32, output_size, input_size)
-Whh = randn(Float32, output_size, output_size)
-b = randn(Float32, output_size)
-
-function rnn_cell(h, x)
-    h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h, h
+struct RecurrentModel{H,C,D}
+    h0::H
+    rnn::C
+    dense::D
 end
-```
 
-Here, we use the last dimension of the input and the hidden state as the batch dimension. I.e., `h[:, n]` would be the hidden state of the nth sentence in the batch.
+Flux.@layer RecurrentModel trainable=(rnn, dense)
 
-```julia
-batch_size = 4
-x = rand(Float32, input_size, batch_size) # dummy input data
-h = rand(Float32, output_size, batch_size) # random initial hidden state
+function RecurrentModel(input_size::Int, hidden_size::Int)
+    return RecurrentModel(
+             zeros(Float32, hidden_size),
+             RNN(input_size => hidden_size),
+             Dense(hidden_size => 1))
+end
 
-h, y = rnn_cell(h, x)
-```
+function (m::RecurrentModel)(x)
+    z = m.rnn(x, m.h0) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
+    ŷ = m.dense(z)     # [1, seq_len, batch_size] or [1, seq_len]
+    return ŷ
+end
 
-```julia
-julia> size(h) == size(y) == (output_size, batch_size)
-true
-```
+seq_len, batch_size, input_size = 3, 4, 2
+x = rand(Float32, input_size, seq_len, batch_size)
+y = rand(Float32, 1, seq_len, batch_size)
 
-In many situations, such as when dealing with a language model, the sentences in each batch are independent (i.e. the last item of the first sentence of the first batch is independent from the first item of the first sentence of the second batch), so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such situations, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
+model = RecurrentModel(input_size, 5)
+opt_state = Flux.setup(AdamW(1e-3), model)
 
-```julia
-function loss(x, y)
-    Flux.reset!(m)
-    sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
-end
+g = gradient(m -> Flux.mse(m(x), y), model)[1]
+Flux.update!(opt_state, model, g)
 ```
-
-A potential source of ambiguity with RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
````

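The removed guide text covered a `seq-to-one` setup, where only the later outputs feed the loss, while the new tutorial is `seq-to-seq` throughout. Under the new design this reduces to indexing the stacked outputs; a sketch reusing the `RecurrentModel` defined in the diff above (the helper name is ours, not part of the commit):

```julia
# Hypothetical seq-to-one head: run the RNN over the whole sequence,
# then apply the dense layer only to the last hidden state.
function last_step_prediction(m::RecurrentModel, x)
    z = m.rnn(x, m.h0)            # (hidden_size, seq_len, batch_size)
    return m.dense(z[:, end, :])  # (1, batch_size)
end
```
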
docs/src/reference/models/layers.md

Lines changed: 0 additions & 2 deletions
````diff
@@ -112,8 +112,6 @@ RNN
 LSTM
 GRU
 GRUv3
-Flux.Recur
-Flux.reset!
 ```
 
 ## Normalisation & Regularisation
````

perf/recurrent.jl

Lines changed: 0 additions & 2 deletions
```diff
@@ -7,12 +7,10 @@ Flux.@functor RNNWrapper
 
 # Need to specialize for RNNWrapper.
 fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
-  Flux.reset!(r.rnn)
   [r.rnn(x) for x in X]
 end
 
 fw(r::RNNWrapper, X) = begin
-  Flux.reset!(r.rnn)
   r.rnn(X)
 end
```

src/Flux.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -35,6 +35,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg
 
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
+       RNNCell, LSTMCell, GRUCell, GRUv3Cell,
        RNN, LSTM, GRU, GRUv3,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
```

src/deprecations.jl

Lines changed: 5 additions & 0 deletions
```diff
@@ -170,3 +170,8 @@ const FluxMetalAdaptor = MetalDevice
 # where `loss_mxy` accepts the model as its first argument.
 # """
 # ))
+
+function reset!(x)
+  Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no more needed.", :reset!)
+  return x
+end
```

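For user code, the migration implied by this deprecation is to drop the `Flux.reset!` call and pass a fresh initial state explicitly; a sketch under the new API (layer and sizes are illustrative):

```julia
using Flux

rnn = RNN(2 => 5)
h0 = zeros(Float32, 5)       # an explicit initial state replaces reset!
x = rand(Float32, 2, 3, 4)   # (features, seq_len, batch)

# The layer never stores state, so each independent sequence simply
# starts again from h0; there is nothing to reset.
y = rnn(x, h0)               # (5, 3, 4)
```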