
Commit f1dbc97

Authored by bors[bot], mkschleg, and DhairyaLGandhi
Merge #1686

1686: Adding support for folding RNNs over 3d arrays r=DhairyaLGandhi a=mkschleg

From #1678: this adds a `Recur`-like interface for a folded operation with support for 3-dimensional arrays. This is how many users familiar with PyTorch and TensorFlow expect RNNs to work, and there seems to be some desire for this feature, as per the discussion in #1671 and from `@jeremiedb`. It will also streamline the push toward supporting the CuDNN versions of RNNs/GRUs/LSTMs, since this is the data layout that API expects.

I did a barebones implementation so we can start iterating on the API. Several questions linger with this interface:

- ~Should we support different modes where we return all or only the last hidden state? Is there a better way to do the concat of the hidden states?~
- What kind of tests should we have? Just follow what we currently do for RNNs/LSTMs/GRUs?
- ~For the CPU version, does it make sense not to specialize on the different rnn types? We might be able to take more advantage of BLAS if we specialized on, say, `Folded{GRU}`.~
- ~Do we want to force the temporal dimension to be the 2nd?~
- ~Do we want this to be stateful? (i.e. allow the user to change what the starting hidden state is, rather than `state0`.)~

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)

Co-authored-by: Matthew Schlegel <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
2 parents 9a395b2 + d1b1daf commit f1dbc97
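For orientation, the interface this commit adds can be sketched as follows. This is a minimal illustration rather than code from the diff; the layer sizes and input shapes are made up, and it assumes a Flux version containing this commit:

```julia
using Flux

rnn = RNN(2, 3)              # 2 input features, 3 hidden units
x = rand(Float32, 2, 1, 5)   # (features, batch, time)
Flux.reset!(rnn)
h = rnn(x)                   # one call folds the cell over the time dimension
size(h)                      # (3, 1, 5): a hidden state for every time step
```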

File tree

3 files changed: +77 −17 lines changed

src/layers/recurrent.jl

Lines changed: 20 additions & 0 deletions
````diff
@@ -24,6 +24,19 @@ rnn.state # 5
 rnn.(1:10) # apply to a sequence
 rnn.state # 60
 ```
+
+Folding over a 3d Array of dimensions `(features, batch, time)` is also supported:
+
+```julia
+accum(h, x) = (h .+ x, x)
+rnn = Flux.Recur(accum, zeros(Int, 1, 1))
+rnn([2]) # 2
+rnn([3]) # 3
+rnn.state # 5
+rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
+rnn.state # 60
+```
+
 """
 mutable struct Recur{T,S}
   cell::T
@@ -53,6 +66,7 @@ rnn.state = hidden(rnn.cell)
 reset!(m::Recur) = (m.state = m.cell.state0)
 reset!(m) = foreach(reset!, functor(m)[1])
 
+
 # TODO remove in v0.13
 function Base.getproperty(m::Recur, sym::Symbol)
   if sym === :init
@@ -67,6 +81,12 @@ end
 
 flip(f, xs) = reverse(f.(reverse(xs)))
 
+function (m::Recur)(x::AbstractArray{T, 3}) where T
+  h = [m(view(x, :, :, i)) for i in 1:size(x, 3)]
+  sze = size(h[1])
+  reshape(reduce(hcat, h), sze[1], sze[2], length(h))
+end
+
 # Vanilla RNN
 
 struct RNNCell{F,A,V,S}
````
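The new `(m::Recur)(x::AbstractArray{T, 3})` method slices the input along its last (time) dimension, runs each `(features, batch)` slice through the wrapped cell in order, and stacks the per-step outputs back into a 3d array. A standalone sketch of that stacking idiom, with hypothetical per-step outputs in place of real cell calls:

```julia
# Seven fake per-step outputs, each of shape (hidden, batch) = (3, 5):
h = [fill(Float32(t), 3, 5) for t in 1:7]

# reduce(hcat, h) gives a (3, 35) matrix; reshape recovers (hidden, batch, time):
y = reshape(reduce(hcat, h), 3, 5, 7)

all(y[:, :, t] == h[t] for t in 1:7)  # true: each slice is one step's output
```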

test/cuda/curnn.jl

Lines changed: 10 additions & 1 deletion
```diff
@@ -53,6 +53,15 @@ end
     y = (rnn(ohx); rnn(ohx))
 
     cuy = (curnn(cuohx); curnn(cuohx))
-    @test y ≈ collect(cuy)
+    @test y ≈ collect(cuy)
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    fx = rand(Float32, 10, batch_size, 3)
+    cufx = gpu(fx)
+    fy = (rnn(fx); rnn(fx))
+
+    cufy = (curnn(cufx); curnn(cufx))
+    @test fy ≈ collect(cufy)
   end
 end
```
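The added assertions exercise the same 3d path on the GPU and check it against the CPU result. A hedged sketch of the pattern being tested, assuming a CUDA-capable setup (without one, `gpu` is a no-op and the comparison is trivially true):

```julia
using Flux

rnn = RNN(10, 5)
curnn = gpu(rnn)                 # move the model's arrays to the device

fx = rand(Float32, 10, 5, 3)     # (features, batch, time)
cufx = gpu(fx)

Flux.reset!(rnn); Flux.reset!(curnn)
fy = rnn(fx)
cufy = curnn(cufx)
fy ≈ collect(cufy)               # CPU and GPU folds should agree
```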

test/layers/recurrent.jl

Lines changed: 47 additions & 16 deletions
```diff
@@ -42,26 +42,57 @@ end
   end
 end
 
+@testset "BPTT-3D" begin
+  seq = rand(Float32, (2, 1, 3))
+  rnn = RNN(2, 3)
+  Flux.reset!(rnn)
+  grads_seq = gradient(Flux.params(rnn)) do
+    sum(rnn(seq)[:, :, 3])
+  end
+  Flux.reset!(rnn);
+  bptt = gradient(rnn.cell.Wh) do Wh
+    # calculate state 1
+    s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] +
+               Wh * rnn.cell.state0 +
+               rnn.cell.b)
+    # calculate state 2
+    s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] +
+               Wh * s1 +
+               rnn.cell.b)
+    # calculate state 3
+    s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] +
+               Wh * s2 +
+               rnn.cell.b)
+    sum(s3) # loss is sum of state 3
+  end
+  @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+end
+
 @testset "RNN-shapes" begin
-  @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m1 = R(3, 5)
-    m2 = R(3, 5)
-    x1 = rand(Float32, 3)
-    x2 = rand(Float32,3,1)
-    Flux.reset!(m1)
-    Flux.reset!(m2)
-    @test size(m1(x1)) == (5,)
-    @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
-    @test size(m2(x2)) == (5,1)
-    @test size(m2(x2)) == (5,1)
-  end
+  @testset for R in [RNN, GRU, LSTM, GRUv3]
+    m1 = R(3, 5)
+    m2 = R(3, 5)
+    m3 = R(3, 5)
+    x1 = rand(Float32, 3)
+    x2 = rand(Float32, 3, 1)
+    x3 = rand(Float32, 3, 1, 2)
+    Flux.reset!(m1)
+    Flux.reset!(m2)
+    Flux.reset!(m3)
+    @test size(m1(x1)) == (5,)
+    @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m2(x2)) == (5, 1)
+    @test size(m2(x2)) == (5, 1)
+    @test size(m3(x3)) == (5, 1, 2)
+    @test size(m3(x3)) == (5, 1, 2)
+  end
 end
 
 @testset "RNN-input-state-eltypes" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m = R(3, 5)
-    x = rand(Float64, 3, 1)
-    Flux.reset!(m)
-    @test_throws MethodError m(x)
+    m = R(3, 5)
+    x = rand(Float64, 3, 1)
+    Flux.reset!(m)
+    @test_throws MethodError m(x)
   end
 end
```