Fixes to Recurrent models for informative type mismatch error & output Vector for Vector input #1521

Merged · 8 commits · Mar 4, 2021
12 changes: 6 additions & 6 deletions docs/src/models/recurrence.md
@@ -8,7 +8,7 @@ To introduce Flux's recurrence functionalities, we will consider the following v

In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs.

An aspect to recognize is that in such model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a dense layer for example is that the cell A is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
An aspect to recognize is that in such a model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).

In the most basic RNN case, cell A could be defined by the following:
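The cell definition itself is collapsed in this diff view. As a rough sketch of the idea only (the names `Wxh`, `Whh`, `b`, and `rnn_cell` are illustrative, not Flux internals), such a cell takes the previous hidden state together with the current input and returns the updated state as its output:

```julia
# Hand-rolled recurrent cell: the new hidden state mixes the current input x
# with the previous hidden state h through two weight matrices and a bias.
Wxh = randn(Float32, 5, 2)
Whh = randn(Float32, 5, 5)
b   = zeros(Float32, 5)

function rnn_cell(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h          # the updated state doubles as the output
end

h = zeros(Float32, 5)    # initial hidden state
x = rand(Float32, 2)     # one input step
h, y = rnn_cell(h, x)
```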

@@ -69,15 +69,15 @@ Recur(RNNCell(2, 5, tanh))

Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.

Using these tools, we can now build the model is the above diagram with:
Using these tools, we can now build the model shown in the above diagram with:

```julia
m = Chain(RNN(2, 5), Dense(5, 1), x -> reshape(x, :))
```

## Working with sequences

Using the previously defined `m` recurrent model, we can the apply it to a single step from our sequence:
Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:

```julia
x = rand(Float32, 2)
@@ -86,7 +86,7 @@ julia> m(x)
0.028398542
```

The m(x) operation would be represented by `x1 -> A -> y1` in our diagram.
The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2` since the model `m` has stored the state resulting from the `x1` step:

```julia
@@ -98,7 +98,7 @@ julia> m(x)

Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by broadcasting the model on a sequence of data.

To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of length = `seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.

```julia
x = [rand(Float32, 2) for i = 1:3]
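# Sketch only — the remainder of this block is collapsed in the diff view.
# Broadcasting the stateful model over the sequence performs the full
# x1 -> y1, x2 -> y2, x3 -> y3 pass in one go.
y = m.(x)   # Vector of 3 outputs, one per time step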
@@ -170,4 +170,4 @@ function loss(x, y)
end
```

A potential source of ambiguity of RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
A potential source of ambiguity with RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
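To make the layout difference concrete, here is a small sketch, assuming illustrative sizes and variable names, of reshaping a `(features, seq length, samples)` array into the `Vector`-of-matrices form described above:

```julia
using Flux

features, seq_length, samples = 2, 3, 4
X = rand(Float32, features, seq_length, samples)   # common 3-D layout

# Flux's layout: a Vector of seq_length matrices, each of size (features, samples)
xs = [X[:, t, :] for t in 1:seq_length]

m = Chain(RNN(features, 5), Dense(5, 1), x -> reshape(x, :))
Flux.reset!(m)
ys = m.(xs)   # Vector of seq_length outputs, one per time step
```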
21 changes: 11 additions & 10 deletions src/layers/recurrent.jl
@@ -30,8 +30,8 @@ mutable struct Recur{T,S}
state::S
end

function (m::Recur)(xs...)
m.state, y = m.cell(m.state, xs...)
function (m::Recur)(x)
m.state, y = m.cell(m.state, x)
return y
end

@@ -80,10 +80,11 @@ end
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell)(h, x)
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
h = σ.(Wi*x .+ Wh*h .+ b)
return h, h
sz = size(x)
return h, reshape(h, :, sz[2:end]...)
end

@functor RNNCell
@@ -133,7 +134,7 @@ function LSTMCell(in::Integer, out::Integer;
return cell
end

function (m::LSTMCell)((h, c), x)
function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
g = m.Wi*x .+ m.Wh*h .+ b
input = σ.(gate(g, o, 1))
@@ -142,7 +143,8 @@ function (m::LSTMCell)((h, c), x)
output = σ.(gate(g, o, 4))
c = forget .* c .+ input .* cell
h′ = output .* tanh.(c)
return (h′, c), h′
sz = size(x)
return (h′, c), reshape(h′, :, sz[2:end]...)
end

@functor LSTMCell
@@ -160,8 +162,6 @@ See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
Recur(m::LSTMCell) = Recur(m, m.state0)

# TODO remove in v0.13
@@ -193,14 +193,15 @@ end
GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell)(h, x)
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
return h′, h′
sz = size(x)
return h′, reshape(h′, :, sz[2:end]...)
end

@functor GRUCell
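Taken together, these changes mean a recurrent cell now mirrors the container shape of its input (`Vector` in, `Vector` out; `Matrix` in, `Matrix` out) and rejects a mismatched element type early. A small sketch of the intended behaviour, mirroring the new tests below:

```julia
using Flux

m = RNN(3, 5)                  # Float32 parameters by default

Flux.reset!(m)
size(m(rand(Float32, 3)))      # (5,)   — Vector input gives a Vector output

Flux.reset!(m)
size(m(rand(Float32, 3, 1)))   # (5, 1) — Matrix input gives a Matrix output

# A Float64 input against Float32 weights no longer produces a confusing
# downstream error; the constrained signatures raise a MethodError instead.
Flux.reset!(m)
m(rand(Float64, 3, 1))         # throws MethodError
```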
44 changes: 34 additions & 10 deletions test/layers/recurrent.jl
@@ -1,14 +1,14 @@
# Ref FluxML/Flux.jl#1209 1D input
@testset "BPTT" begin
@testset "BPTT-1D" begin
seq = [rand(Float32, 2) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
rnn = r(2, 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
sum(rnn.(seq)[3])
end
Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.cell.state0
@@ -17,20 +17,20 @@
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
end
end

# Ref FluxML/Flux.jl#1209 2D input
@testset "BPTT" begin
seq = [rand(Float32, (2,1)) for i = 1:3]
@testset "BPTT-2D" begin
seq = [rand(Float32, (2, 1)) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
rnn = r(2, 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
sum(rnn.(seq)[3])
end
Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.cell.state0
@@ -39,5 +39,29 @@ end
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
end

@testset "RNN-shapes" begin
@testset for R in [RNN, GRU, LSTM]
m1 = R(3, 5)
m2 = R(3, 5)
x1 = rand(Float32, 3)
x2 = rand(Float32,3,1)
Flux.reset!(m1)
Flux.reset!(m2)
@test size(m1(x1)) == (5,)
@test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
@test size(m2(x2)) == (5,1)
@test size(m2(x2)) == (5,1)
end
end

@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM]
m = R(3, 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
end
end