Skip to content

Commit fd22bc3

Browse files
committed
fix broken tests on CUDA RNN
1 parent 5483a12 commit fd22bc3

File tree

2 files changed

+289
-5
lines changed

2 files changed

+289
-5
lines changed

test/cuda/curnn.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Flux, CUDA, Test
2-
using Flux: pullback
32

43
@testset for R in [RNN, GRU, LSTM]
54
m = R(10, 5) |> gpu
@@ -9,7 +8,7 @@ using Flux: pullback
98
θ = gradient(() -> sum(m(x)), params(m))
109
@test x isa CuArray
1110
@test θ[m.cell.Wi] isa CuArray
12-
@test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
11+
@test collect(m̄[].cell.Wi) == collect(θ[m.cell.Wi])
1312
end
1413

1514
@testset "RNN" begin
@@ -34,9 +33,9 @@ end
3433
cum̄, cux̄ = cuback(gpu(ȳ))
3534

3635
@test collect(cux̄)
37-
@test_broken m̄[].cell[].Wi collect(cum̄[].cell[].Wi)
38-
@test_broken m̄[].cell[].Wh collect(cum̄[].cell[].Wh)
39-
@test_broken m̄[].cell[].b collect(cum̄[].cell[].b)
36+
@test m̄[].cell.Wi collect(cum̄[].cell.Wi)
37+
@test m̄[].cell.Wh collect(cum̄[].cell.Wh)
38+
@test m̄[].cell.b collect(cum̄[].cell.b)
4039
if m̄[].state isa Tuple
4140
for (x, cx) in zip(m̄[].state, cum̄[].state)
4241
@test x collect(cx)

test/rnn-demo.jl

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
using Revise
2+
using Flux
3+
using Flux: @functor
4+
import Flux: trainable
5+
using Statistics: mean
6+
using Random: seed!
7+
8+
9+
mutable struct Recur2{T,S}
10+
cell::T
11+
state::S
12+
end
13+
14+
# original definition
15+
# function (m::Recur2)(xs...)
16+
# m.state, y = m.cell(m.state, xs...)
17+
# return y
18+
# end
19+
20+
# new def
21+
function (m::Recur2)(xs...)
22+
m.state, y = m.cell(m.state, xs...)
23+
return y
24+
end
25+
26+
@functor Recur2
27+
trainable(a::Recur2) = (a.cell,)
28+
29+
#####################################
30+
# Basic test
31+
#####################################
32+
seed!(123)
33+
feat = 3
34+
h_size = 5
35+
seq_len = 7
36+
batch_size = 4
37+
38+
X = [rand(Float32, feat, batch_size) for i in 1:seq_len]
39+
Y = rand(Float32, batch_size, seq_len) ./ 10
40+
41+
cell = Flux.RNNCell(feat, h_size)
42+
rnn = Recur2(cell, cell.state0)
43+
44+
rnn(X[1])
45+
rnn.state
46+
rnn(X[1])
47+
48+
rnn.(X)
49+
50+
function fold_test_1(x, m)
51+
foldl((a, b) -> m(b), x)
52+
end
53+
fold_test_1(X, rnn)
54+
55+
rnn.(X)
56+
57+
function rnn2(x)
58+
# println((x))
59+
println("state: ", rnn.state)
60+
rnn(x)
61+
end
62+
function fold_test_2(x)
63+
foldl((a, b) -> rnn(b), x, init=x[1])
64+
end
65+
fold_test_2(X)
66+
rnn.state
67+
68+
function fold_cell_1(x, c)
69+
foldl((a, b) -> cell(a, b)[1], x, init=cell.state0)
70+
end
71+
fold_cell_1(X, cell)
72+
rnn.state
73+
74+
75+
f1(x) = begin
76+
println(x)
77+
x^2
78+
end
79+
80+
function fold_test_2(x)
81+
foldl((a, b) -> f1(b), x, init=5)
82+
end
83+
x1 = fold_test_2([2,3])
84+
85+
# rnn = Chain(
86+
# RNN(feat, h_size),
87+
# Dense(h_size, 1, σ),
88+
# x -> reshape(x, :))
89+
90+
91+
#### transfer to gpu ####
92+
rnn_gpu = rnn |> gpu
93+
X_gpu = gpu(X)
94+
Y_gpu = gpu(Y)
95+
96+
θ = Flux.params(rnn)
97+
θ_gpu = Flux.params(rnn_gpu)
98+
length(θ)
99+
length(θ_gpu)
100+
function loss(x, y)
101+
Flux.reset!(rnn)
102+
l = mean((Flux.stack(map(rnn, x), 2) .- y).^2)
103+
return l
104+
end
105+
function loss_gpu(x, y)
106+
Flux.reset!(rnn_gpu)
107+
l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2)
108+
return l
109+
end
110+
111+
opt = ADAM(1e-3)
112+
opt_gpu = ADAM(1e-3)
113+
for i in 1:5
114+
println("iter: ", i)
115+
Flux.train!(loss, θ, [(X, Y)], opt)
116+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
117+
println("loss_cpu: ", loss(X, Y))
118+
println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu))
119+
# println("θ[3][1:2]: ", θ[3][1:2])
120+
# println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2])
121+
# println("θ[4][1:2]: ", θ[4][1:2])
122+
# println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2])
123+
# println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2])
124+
# println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2])
125+
end
126+
127+
@code_warntype rnn(X[1])
128+
129+
function speed_cpu(n=10)
130+
for i in 1:n
131+
Flux.train!(loss, θ, [(X, Y)], opt)
132+
end
133+
return loss(X, Y)
134+
end
135+
136+
function speed_gpu(n=10)
137+
for i in 1:n
138+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
139+
end
140+
return loss_gpu(X_gpu, Y_gpu)
141+
end
142+
143+
@time speed_cpu(100)
144+
@time speed_gpu(100)
145+
146+
147+
#####################################
148+
# RNN vanilla
149+
#####################################
150+
seed!(123)
151+
feat = 32
152+
h_size = 64
153+
seq_len = 50
154+
batch_size = 256
155+
156+
rnn = Chain(
157+
RNN(feat, h_size),
158+
Dense(h_size, 1, σ),
159+
x -> reshape(x, :))
160+
161+
X = [rand(Float32, feat, batch_size) for i in 1:seq_len]
162+
Y = rand(Float32, batch_size, seq_len) ./ 10
163+
164+
#### transfer to gpu ####
165+
rnn_gpu = rnn |> gpu
166+
X_gpu = gpu(X)
167+
Y_gpu = gpu(Y)
168+
169+
θ = Flux.params(rnn)
170+
θ_gpu = Flux.params(rnn_gpu)
171+
length(θ)
172+
length(θ_gpu)
173+
function loss(x, y)
174+
Flux.reset!(rnn)
175+
l = mean((Flux.stack(map(rnn, x), 2) .- y).^2)
176+
return l
177+
end
178+
function loss_gpu(x, y)
179+
Flux.reset!(rnn_gpu)
180+
l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2)
181+
return l
182+
end
183+
184+
opt = ADAM(1e-3)
185+
opt_gpu = ADAM(1e-3)
186+
for i in 1:5
187+
println("iter: ", i)
188+
Flux.train!(loss, θ, [(X, Y)], opt)
189+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
190+
println("loss_cpu: ", loss(X, Y))
191+
println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu))
192+
# println("θ[3][1:2]: ", θ[3][1:2])
193+
# println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2])
194+
# println("θ[4][1:2]: ", θ[4][1:2])
195+
# println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2])
196+
# println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2])
197+
# println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2])
198+
end
199+
200+
@code_warntype rnn(X[1])
201+
202+
function speed_cpu(n=10)
203+
for i in 1:n
204+
Flux.train!(loss, θ, [(X, Y)], opt)
205+
end
206+
return loss(X, Y)
207+
end
208+
209+
function speed_gpu(n=10)
210+
for i in 1:n
211+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
212+
end
213+
return loss_gpu(X_gpu, Y_gpu)
214+
end
215+
216+
@time speed_cpu(100)
217+
@time speed_gpu(100)
218+
219+
#####################################
220+
# LSTM
221+
#####################################
222+
feat = 32
223+
h_size = 64
224+
seq_len = 50
225+
batch_size = 256
226+
227+
rnn = Chain(LSTM(feat, h_size),
228+
LSTM(h_size, h_size),
229+
LSTM(h_size, h_size),
230+
Dense(h_size, 1, σ),
231+
x -> reshape(x, :))
232+
233+
X = [rand(Float32, feat, batch_size) for i in 1:seq_len]
234+
Y = rand(Float32, batch_size, seq_len) ./ 10
235+
236+
#### transfer to gpu ####
237+
rnn_gpu = rnn |> gpu
238+
X_gpu = gpu(X)
239+
Y_gpu = gpu(Y)
240+
241+
θ = Flux.params(rnn)
242+
θ_gpu = Flux.params(rnn_gpu)
243+
function loss(x, y)
244+
Flux.reset!(rnn)
245+
l = mean((Flux.stack(map(rnn, x), 2) .- y).^2)
246+
return l
247+
end
248+
function loss_gpu(x, y)
249+
Flux.reset!(rnn_gpu)
250+
l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2)
251+
return l
252+
end
253+
254+
opt = ADAM(1e-3)
255+
opt_gpu = ADAM(1e-3)
256+
257+
for i in 1:5
258+
println("iter: ", i)
259+
Flux.train!(loss, θ, [(X, Y)], opt)
260+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
261+
println("loss_cpu: ", loss(X, Y))
262+
println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu))
263+
end
264+
265+
266+
function speed_cpu(n=10)
267+
for i in 1:n
268+
Flux.train!(loss, θ, [(X, Y)], opt)
269+
end
270+
return loss(X, Y)
271+
end
272+
273+
function speed_gpu(n=10)
274+
for i in 1:n
275+
Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
276+
end
277+
return loss_gpu(X_gpu, Y_gpu)
278+
end
279+
280+
@code_warntype rnn(X[1])
281+
282+
using BenchmarkTools
283+
@time speed_cpu(100)
284+
@btime speed_gpu(100)
285+

0 commit comments

Comments
 (0)