Skip to content

Commit 2c0dcb2

Browse files
bors[bot]jeremiedb
andauthored
Merge #1473
1473: Fix RNN tests on GPU r=DhairyaLGandhi a=jeremiedb Fix for RNN on CUDA, as discussed in #1367 . Co-authored-by: jeremie.db <[email protected]>
2 parents 5483a12 + 0b147d8 commit 2c0dcb2

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

test/cuda/curnn.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Flux, CUDA, Test
2-
using Flux: pullback
32

43
@testset for R in [RNN, GRU, LSTM]
54
m = R(10, 5) |> gpu
@@ -9,7 +8,7 @@ using Flux: pullback
98
θ = gradient(() -> sum(m(x)), params(m))
109
@test x isa CuArray
1110
@test θ[m.cell.Wi] isa CuArray
12-
@test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
11+
@test collect(m̄[].cell.Wi) == collect(θ[m.cell.Wi])
1312
end
1413

1514
@testset "RNN" begin
@@ -34,9 +33,9 @@ end
3433
cum̄, cux̄ = cuback(gpu(ȳ))
3534

3635
@test collect(cux̄)
37-
@test_broken m̄[].cell[].Wi collect(cum̄[].cell[].Wi)
38-
@test_broken m̄[].cell[].Wh collect(cum̄[].cell[].Wh)
39-
@test_broken m̄[].cell[].b collect(cum̄[].cell[].b)
36+
@test m̄[].cell.Wi collect(cum̄[].cell.Wi)
37+
@test m̄[].cell.Wh collect(cum̄[].cell.Wh)
38+
@test m̄[].cell.b collect(cum̄[].cell.b)
4039
if m̄[].state isa Tuple
4140
for (x, cx) in zip(m̄[].state, cum̄[].state)
4241
@test x collect(cx)

0 commit comments

Comments
 (0)