File tree 1 file changed +4
-5
lines changed
1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change 1
1
using Flux, CUDA, Test
2
- using Flux: pullback
3
2
4
3
@testset for R in [RNN, GRU, LSTM]
5
4
m = R (10 , 5 ) |> gpu
@@ -9,7 +8,7 @@ using Flux: pullback
9
8
θ = gradient (() -> sum (m (x)), params (m))
10
9
@test x isa CuArray
11
10
@test θ[m. cell. Wi] isa CuArray
12
- @test_broken collect (m̄[]. cell[] . Wi) == collect (θ[m. cell. Wi])
11
+ @test collect (m̄[]. cell. Wi) == collect (θ[m. cell. Wi])
13
12
end
14
13
15
14
@testset " RNN" begin
34
33
cum̄, cux̄ = cuback (gpu (ȳ))
35
34
36
35
@test x̄ ≈ collect (cux̄)
37
- @test_broken m̄[]. cell[] . Wi ≈ collect (cum̄[]. cell[] . Wi)
38
- @test_broken m̄[]. cell[] . Wh ≈ collect (cum̄[]. cell[] . Wh)
39
- @test_broken m̄[]. cell[] . b ≈ collect (cum̄[]. cell[] . b)
36
+ @test m̄[]. cell. Wi ≈ collect (cum̄[]. cell. Wi)
37
+ @test m̄[]. cell. Wh ≈ collect (cum̄[]. cell. Wh)
38
+ @test m̄[]. cell. b ≈ collect (cum̄[]. cell. b)
40
39
if m̄[]. state isa Tuple
41
40
for (x, cx) in zip (m̄[]. state, cum̄[]. state)
42
41
@test x ≈ collect (cx)
You can’t perform that action at this time.
0 commit comments