
Commit dcf4808

Author: jeremiedb

make RNN/LSTM/GRU Cells immutable - FluxML#1089
check that CUDNN drop solves for too many wrappers - FluxML#1259

1 parent: 759de3f

File tree

2 files changed: +11 -3 lines changed

src/layers/recurrent.jl

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@ flip(f, xs) = reverse(f.(reverse(xs)))

 # Vanilla RNN

-mutable struct RNNCell{F,A,V}
+struct RNNCell{F,A,V}
   σ::F
   Wi::A
   Wh::A
@@ -96,7 +96,7 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

 # LSTM

-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,V}
   Wi::A
   Wh::A
   b::V
@@ -141,7 +141,7 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

 # GRU

-mutable struct GRUCell{A,V}
+struct GRUCell{A,V}
   Wi::A
   Wh::A
   b::V
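The change above drops `mutable` from the cell structs: a cell holds only its parameters (`Wi`, `Wh`, `b`, and the activation), while the evolving hidden state lives in the mutable `Recur` wrapper, so nothing in the cell itself is ever mutated. A minimal self-contained sketch of that design (hypothetical `MyRNNCell`/`MyRecur` names, not Flux's actual implementation):

```julia
# Immutable cell: fields are fixed after construction; it only holds parameters.
struct MyRNNCell{F,A,V}
    σ::F
    Wi::A
    Wh::A
    b::V
end

# A cell maps (state, input) -> (newstate, output); here the output is the new state.
function (c::MyRNNCell)(h, x)
    h′ = c.σ.(c.Wi * x .+ c.Wh * h .+ c.b)
    return h′, h′
end

# Mutable wrapper: this is the only place the per-step state is overwritten.
mutable struct MyRecur{T,S}
    cell::T
    state::S
end

function (m::MyRecur)(x)
    m.state, y = m.cell(m.state, x)
    return y
end

cell = MyRNNCell(tanh, randn(2, 3), randn(2, 2), zeros(2))
m = MyRecur(cell, zeros(2))
y = m(ones(3))   # updates m.state and returns the 2-element output
```

Because the cell is immutable, wrappers like `gpu`/`cu` conversions can rebuild it field-by-field instead of layering adaptors around a mutable object, which is the mechanism behind the "too many wrappers" concern in FluxML#1259.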

test/rnn-test-jdb.jl

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,14 @@ using Flux
 # using CUDA
 using Statistics: mean

+################################################
+# Too many wrappers issue #1259
+################################################
+m = RNN(3,2) |> gpu
+x = CUDA.ones(3,2)
+gs = gradient(() -> sum(m(x)), params(m))
+gs[m.cell.Wi]
+
 ########################
 # RNN test gpu
 ########################
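The added test exercises the FluxML#1259 "too many wrappers" case on the GPU: it takes an implicit-parameters gradient through a `gpu`-moved `RNN` and indexes the gradient dictionary by the cell's weight matrix. For reference, the same gradient check can be run on the CPU without CUDA (an illustrative sketch, not part of the commit):

```julia
using Flux

m = RNN(3, 2)                        # Recur wrapping an (immutable) RNNCell
x = ones(Float32, 3, 2)              # 3 features, batch of 2
gs = gradient(() -> sum(m(x)), params(m))  # implicit-params gradient
gs[m.cell.Wi]                        # gradient with the same shape as Wi
```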
