You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -6,7 +6,7 @@ To introduce Flux's recurrence functionalities, we will consider the following v
6
6
7
7

8
8
9
-
In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and`y1` to `y3` are their respective outputs.
9
+
In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step. It could be a timestamp or a word in a sentence encoded as vectors.`y1` to `y3` are their respective outputs.
10
10
11
11
An aspect to recognise is that in such a model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
12
12
@@ -17,213 +17,152 @@ output_size = 5
17
17
input_size =2
18
18
Wxh =randn(Float32, output_size, input_size)
19
19
Whh =randn(Float32, output_size, output_size)
20
-
b =randn(Float32, output_size)
20
+
b =zeros(Float32, output_size)
21
21
22
-
functionrnn_cell(h, x)
22
+
functionrnn_cell(x, h)
23
23
h =tanh.(Wxh * x .+ Whh * h .+ b)
24
-
return h, h
24
+
return h
25
25
end
26
26
27
-
x =rand(Float32, input_size) # dummy input data
28
-
h =rand(Float32, output_size) # random initial hidden state
29
-
30
-
h, y =rnn_cell(h, x)
27
+
seq_len =3
28
+
# dummy input data
29
+
x = [rand(Float32, input_size) for i =1:seq_len]
30
+
# random initial hidden state
31
+
h0 =zeros(Float32, output_size)
32
+
33
+
y = []
34
+
ht = h0
35
+
for xt in x
36
+
ht =rnn_cell(xt, ht)
37
+
y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
38
+
end
31
39
```
32
40
33
-
Notice how the above is essentially a `Dense` layer that acts on two inputs, `h` and `x`.
41
+
Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
34
42
35
-
If you run the last line a few times, you'll notice the output `y` changing slightly even though the input `x` is the same.
43
+
The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
36
44
37
45
There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:
The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
49
66
50
-
## Stateful Models
67
+
## Using a cell as part of a model
51
68
52
-
For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.
69
+
Let's consider a simple model that is trained to predict a scalar quantity for each time step in a sequence. The model will have a single RNN cell, followed by a dense layer to produce the output.
70
+
Since the [`RNNCell`](@ref) can deal with batches of data, we can define the model to accept an input where
71
+
at each time step, the input is a matrix of size `(input_size, batch_size)`.
53
72
54
73
```julia
55
-
x =rand(Float32, 2)
56
-
h =rand(Float32, 5)
57
-
58
-
m = Flux.Recur(rnn, h)
59
-
60
-
y =m(x)
61
-
```
62
-
63
-
The `Recur` wrapper stores the state between runs in the `m.state` field.
64
-
65
-
If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
66
-
67
-
```jldoctest recurrence
68
-
julia> using Flux
69
-
70
-
julia> RNN(2, 5) # or equivalently RNN(2 => 5)
71
-
Recur(
72
-
RNNCell(2 => 5, tanh), # 45 parameters
73
-
) # Total: 4 trainable arrays, 45 parameters,
74
-
# plus 1 non-trainable, 5 parameters, summarysize 404 bytes.
75
-
```
76
-
77
-
Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
78
-
79
-
Using these tools, we can now build the model shown in the above diagram with:
80
-
81
-
```jldoctest recurrence
82
-
julia> m = Chain(RNN(2 => 5), Dense(5 => 1))
83
-
Chain(
84
-
Recur(
85
-
RNNCell(2 => 5, tanh), # 45 parameters
86
-
),
87
-
Dense(5 => 1), # 6 parameters
88
-
) # Total: 6 trainable arrays, 51 parameters,
89
-
# plus 1 non-trainable, 5 parameters, summarysize 540 bytes.
90
-
```
91
-
In this example, each output has only one component.
92
-
93
-
## Working with sequences
94
-
95
-
Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:
The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
106
-
If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2`
107
-
since the model `m` has stored the state resulting from the `x1` step.
108
-
109
-
Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by
110
-
iterating the model on a sequence of data.
111
-
112
-
To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
z =stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
98
+
ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
99
+
return ŷ
100
+
end
150
101
```
151
102
152
-
In such a model, only the last two outputs are used to compute the loss, hence the target `y` being of length 2. This is a strategy that can be used to easily handle a `seq-to-one` kind of structure, compared to the `seq-to-seq` assumed so far.
103
+
Notice that we stack the hidden states `z`to form a tensor of size `(hidden_size, seq_len, batch_size)`. This can speedup the final classification, since we then process all the outputs at once with a single forward pass of the dense layer.
153
104
154
-
Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update:
105
+
Let's now define the training loop for this model:
155
106
156
107
```julia
157
-
functionloss(m, x, y)
158
-
sum(mse(m(xi), yi) for (xi, yi) inzip(x, y))
159
-
end
160
-
161
-
seq_init = [rand(Float32, 2)]
162
-
seq_1 = [rand(Float32, 2) for i =1:3]
163
-
seq_2 = [rand(Float32, 2) for i =1:3]
108
+
using Optimisers: AdamW
164
109
165
-
y1 = [rand(Float32, 1) for i =1:3]
166
-
y2 = [rand(Float32, 1) for i =1:3]
110
+
functionloss(model, x, y)
111
+
ŷ =model(x)
112
+
y =stack(y, dims=2)
113
+
return Flux.mse(ŷ, y)
114
+
end
167
115
168
-
X = [seq_1, seq_2]
169
-
Y = [y1, y2]
170
-
data =zip(X,Y)
116
+
# create dummy data
117
+
seq_len, batch_size, input_size =3, 4, 2
118
+
x = [rand(Float32, input_size, batch_size) for _ =1:seq_len]
119
+
y = [rand(Float32, 1, batch_size) for _ =1:seq_len]
171
120
172
-
[m(x) for x in seq_init]
121
+
# initialize the model and optimizer
122
+
model =RecurrentCellModel(input_size, 5)
123
+
opt_state = Flux.setup(AdamW(1e-3), model)
173
124
174
-
opt = Flux.setup(Adam(1e-3), m)
175
-
Flux.train!(loss, m, data, opt)
125
+
# compute the gradient and update the model
126
+
g =gradient(m ->loss(m, x, y),model)[1]
127
+
Flux.update!(opt_state, model, g)
176
128
```
177
129
178
-
Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss.
130
+
## Handling the whole sequence at once
179
131
180
-
In this scenario, it is important to note that a single continuous sequence is considered. Since the model state is not reset between the 2 batches, the state of the model flows through the batches, which only makes sense in the context where `seq_1` is the continuation of `seq_init` and so on.
132
+
In the above example, we processed the sequence one time step at a time using a recurrent cell. However, it is possible to process the entire sequence at once. This can be done by stacking the input data `x` to form a tensor of size `(input_size, seq_len)` or `(input_size, seq_len, batch_size)`.
133
+
One can then use the [`RNN`](@ref), [`LSTM`](@ref) or [`GRU`](@ref) layers to process the entire input tensor.
181
134
182
-
Batch size would be 1 here as there's only a single sequence within each batch. If the model was to be trained on multiple independent sequences, then these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such scenario, if we set the batch size to 4, a single batch would be of the shape:
135
+
Let's consider the same example as above, but this time we use an `RNN` layer instead of an `RNNCell`:
183
136
184
137
```julia
185
-
x = [rand(Float32, 2, 4) for i =1:3]
186
-
y = [rand(Float32, 1, 4) for i =1:3]
187
-
```
188
-
189
-
That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])`, would still represent `x1 -> y1` in our diagram and returns the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix). Each sentence in the batch will output in its own "column", and the outputs of the different sentences won't mix.
190
-
191
-
To illustrate, we go through an example of batching with our implementation of `rnn_cell`. The implementation doesn't need to change; the batching comes for "free" from the way Julia does broadcasting and the rules of matrix multiplication.
192
-
193
-
```julia
194
-
output_size =5
195
-
input_size =2
196
-
Wxh =randn(Float32, output_size, input_size)
197
-
Whh =randn(Float32, output_size, output_size)
198
-
b =randn(Float32, output_size)
199
-
200
-
functionrnn_cell(h, x)
201
-
h =tanh.(Wxh * x .+ Whh * h .+ b)
202
-
return h, h
138
+
struct RecurrentModel{H,C,D}
139
+
h0::H
140
+
rnn::C
141
+
dense::D
203
142
end
204
-
```
205
143
206
-
Here, we use the last dimension of the input and the hidden state as the batch dimension. I.e., `h[:, n]` would be the hidden state of the nth sentence in the batch.
144
+
Flux.@layer RecurrentModel trainable=(rnn, dense)
207
145
208
-
```julia
209
-
batch_size =4
210
-
x =rand(Float32, input_size, batch_size) # dummy input data
211
-
h =rand(Float32, output_size, batch_size) # random initial hidden state
In many situations, such as when dealing with a language model, the sentences in each batch are independent (i.e. the last item of the first sentence of the first batch is independent from the first item of the first sentence of the second batch), so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such situations, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
163
+
model =RecurrentModel(input_size, 5)
164
+
opt_state = Flux.setup(AdamW(1e-3), model)
222
165
223
-
```julia
224
-
functionloss(x, y)
225
-
sum(mse(m(xi), yi) for (xi, yi) inzip(x, y))
226
-
end
166
+
g =gradient(m -> Flux.mse(m(x), y), model)[1]
167
+
Flux.update!(opt_state, model, g)
227
168
```
228
-
229
-
A potential source of ambiguity with RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
0 commit comments