Commit 0e0e2e7

bors[bot] and jeremiedb authored
Merge #1367
1367: RNN update to drop CUDNN, fix LSTM bug and output type stability r=CarloLucibello a=jeremiedb

PR related to #1114 #1360 #1365

An experiment with RNN handling. The hidden state stored in each cell struct was dropped, as it wasn't needed (AFAIK it was only used for size inference by CUDNN, and the bias size could serve as a substitute for the cells' `h` there as well). The CUDNN dependency was dropped entirely, so this is pure Flux/CUDA.jl; the file `src/cuda/curnn.jl` is no longer used. No modifications were made to the cell computations. Initial tests show decent performance, but it has yet to be benchmarked.

Pending issue: despite the CUDNN dependency being dropped completely, an instability still seems to be present when running on the GPU. It is illustrated by the test at lines 1-50 of `test\rnn-test-jdb.jl`: on the CPU that test runs cleanly through the 100 iterations, but on the GPU it throws NaNs after a couple dozen iterations. My only hypothesis so far: when iterating over the sequence through `m.(x)` or `map(rnn, x)`, is the execution order safe? That is, could there be a missing `sync()` on the CUDA side between sequence steps that messes up the state?

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: jeremiedb <[email protected]>
Co-authored-by: jeremie.db <[email protected]>
2 parents 8fb94be + b5c3b6f commit 0e0e2e7
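
For reference, a minimal sketch of the sequence-iteration pattern discussed in the stability question above (illustrative only; the layer sizes and data are made up and not taken from `test\rnn-test-jdb.jl`):

```julia
using Flux

m = RNN(4, 2)                              # Recur wrapping an RNNCell after this PR
xs = [rand(Float32, 4, 3) for _ in 1:10]   # 10 time steps, batch of 3

Flux.reset!(m)        # restore m.state from the cell's stored initial state
y1 = m.(xs)           # broadcast over the sequence, as questioned above
Flux.reset!(m)
y2 = map(m, xs)       # equivalent map form
```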

File tree: 7 files changed (+40, -132 lines)


.gitignore

Lines changed: 2 additions & 1 deletion

@@ -5,4 +5,5 @@
 docs/build/
 docs/site/
 deps
-# Manifest.toml
+.vscode
+# Manifest.toml

src/cuda/cuda.jl

Lines changed: 5 additions & 2 deletions

@@ -1,9 +1,12 @@
 module CUDAint

 using ..CUDA
-
 using CUDA: CUDNN
-include("curnn.jl")
+
+import ..Flux: Flux
+import Zygote
+using Zygote: @adjoint
+
 include("cudnn.jl")

 end

src/cuda/curnn.jl

Lines changed: 0 additions & 89 deletions
This file was deleted.
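
With `curnn.jl` removed, recurrent layers on the GPU go through the same Flux code path as on the CPU, using plain CUDA.jl arrays instead of CUDNN RNN descriptors. A minimal sketch, assuming a working CUDA setup (sizes are illustrative):

```julia
using Flux, CUDA

m = LSTM(10, 5) |> gpu                           # no CUDNN-specific RNN wrapper involved
xs = [gpu(rand(Float32, 10, 16)) for _ in 1:20]  # 20 time steps, batch of 16

Flux.reset!(m)
ys = [m(x) for x in xs]                          # explicit loop over the sequence steps
```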

src/layers/recurrent.jl

Lines changed: 25 additions & 30 deletions

@@ -1,3 +1,4 @@
+
 gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -24,21 +25,19 @@ rnn.(1:10) # apply to a sequence
 rnn.state # 60
 ```
 """
-mutable struct Recur{T}
+mutable struct Recur{T,S}
   cell::T
-  init
-  state
+  state::S
 end

-Recur(m, h = hidden(m)) = Recur(m, h, h)
-
 function (m::Recur)(xs...)
   h, y = m.cell(m.state, xs...)
   m.state = h
   return y
 end

-@functor Recur cell, init
+@functor Recur
+trainable(a::Recur) = (a.cell,)

 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -52,34 +51,30 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
 rnn.state = hidden(rnn.cell)
 ```
 """
-reset!(m::Recur) = (m.state = m.init)
+reset!(m::Recur) = (m.state = m.cell.state)
 reset!(m) = foreach(reset!, functor(m)[1])

 flip(f, xs) = reverse(f.(reverse(xs)))

 # Vanilla RNN

-mutable struct RNNCell{F,A,V}
+struct RNNCell{F,A,V,S}
   σ::F
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end

-RNNCell(in::Integer, out::Integer, σ = tanh;
-        init = glorot_uniform) =
-  RNNCell(σ, init(out, in), init(out, out),
-          init(out), zeros(out))
+RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
+  RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
   h = σ.(Wi*x .+ Wh*h .+ b)
   return h, h
 end

-hidden(m::RNNCell) = m.h
-
 @functor RNNCell

 function Base.show(io::IO, l::RNNCell)
@@ -94,22 +89,23 @@ end
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 """
+Recur(m::RNNCell) = Recur(m, m.state)
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

 # LSTM

-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
-  c::V
+  state::S
 end

 function LSTMCell(in::Integer, out::Integer;
-                  init = glorot_uniform)
-  cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
-                  zeros(out), zeros(out))
+                  init = glorot_uniform,
+                  initb = zeros,
+                  init_state = zeros)
+  cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))
   cell.b[gate(out, 2)] .= 1
   return cell
 end
@@ -126,8 +122,6 @@ function (m::LSTMCell)((h, c), x)
   return (h′, c), h′
 end

-hidden(m::LSTMCell) = (m.h, m.c)
-
 @functor LSTMCell

 Base.show(io::IO, l::LSTMCell) =
@@ -142,20 +136,22 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
+#   (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
+Recur(m::LSTMCell) = Recur(m, m.state)
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

 # GRU

-mutable struct GRUCell{A,V}
+struct GRUCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end

-GRUCell(in, out; init = glorot_uniform) =
-  GRUCell(init(out * 3, in), init(out * 3, out),
-          init(out * 3), zeros(out))
+GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
+  GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -167,8 +163,6 @@ function (m::GRUCell)(h, x)
   return h′, h′
 end

-hidden(m::GRUCell) = m.h
-
 @functor GRUCell

 Base.show(io::IO, l::GRUCell) =
@@ -183,6 +177,7 @@ RNN but generally exhibits a longer memory span over sequences.
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+Recur(m::GRUCell) = Recur(m, m.state)
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

 @adjoint function Broadcast.broadcasted(f::Recur, args...)
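
A short sketch of how the refactored pieces above fit together (field and function names follow the diff; the values are illustrative, not from the PR):

```julia
using Flux
using Flux: Recur, RNNCell

cell = RNNCell(4, 2)       # holds Wi, Wh, b plus a stored initial state
m = Recur(cell)            # Recur(m::RNNCell) = Recur(m, m.state)

y = m(rand(Float32, 4))    # one step: updates m.state and returns the hidden state
Flux.reset!(m)             # reset!(m::Recur) = (m.state = m.cell.state)

Flux.trainable(m)          # (m.cell,); the running state is not a trainable parameter
```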

test/cuda/curnn.jl

Lines changed: 3 additions & 5 deletions

@@ -8,7 +8,7 @@ using Flux: pullback
   Flux.reset!(m)
   θ = gradient(() -> sum(m(x)), params(m))
   @test x isa CuArray
-  @test_broken θ[m.cell.Wi] isa CuArray
+  @test θ[m.cell.Wi] isa CuArray
   @test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
 end

@@ -20,17 +20,15 @@ end
     Flux.reset!(rnn)
     Flux.reset!(curnn)
     x = batch_size == 1 ?
-      rand(10) :
-      rand(10, batch_size)
+      rand(Float32, 10) :
+      rand(Float32, 10, batch_size)
     cux = gpu(x)

     y, back = pullback((r, x) -> r(x), rnn, x)
     cuy, cuback = pullback((r, x) -> r(x), curnn, cux)

     @test y ≈ collect(cuy)

-    @test haskey(Flux.CUDAint.descs, curnn.cell)
-
     ȳ = randn(size(y))
     m̄, x̄ = back(ȳ)
     cum̄, cux̄ = cuback(gpu(ȳ))

test/layers/recurrent.jl

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 # Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
-  seq = [rand(2) for i = 1:3]
+  seq = [rand(Float32, (2,1)) for i = 1:3]
   for r ∈ [RNN,]
     rnn = r(2,3)
     Flux.reset!(rnn)
@@ -11,7 +11,7 @@
     bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                           tanh.(rnn.cell.Wi * seq[2] + Wh *
                                 tanh.(rnn.cell.Wi * seq[1] +
-                                      Wh * rnn.init
+                                      Wh * rnn.cell.state
                                       + rnn.cell.b)
                                 + rnn.cell.b)
                           + rnn.cell.b)),

test/utils.jl

Lines changed: 3 additions & 3 deletions

@@ -101,16 +101,16 @@ end
   m = Dense(10, 5)
   @test size.(params(m)) == [(5, 10), (5,)]
   m = RNN(10, 5)
-  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]

   # Layer duplicated in same chain, params just once pls.
   c = Chain(m, m)
-  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]

   # Self-referential array. Just want params, no stack overflow pls.
   r = Any[nothing,m]
   r[1] = r
-  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
 end

 @testset "Basic Stacking" begin
