Skip to content

Commit 7a0140e

Browse files
committed
test: workaround Enzyme warning
1 parent 677b2ac commit 7a0140e

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

lib/LuxLib/test/common_ops/activation_tests.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
@testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin
2+
using Enzyme
3+
24
rng = StableRNG(1234)
35

46
apply_act(f::F, x) where {F} = sum(abs2, f.(x))
5-
apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x)))
7+
function apply_act_fast(f::F, x) where {F}
8+
if Enzyme.within_autodiff()
9+
y = similar(x)
10+
y .= x
11+
return sum(abs2, fast_activation!!(f, y))
12+
end
13+
return sum(abs2, fast_activation!!(f, copy(x)))
14+
end
615
apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x))
716

817
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES

test/layers/basic_tests.jl

+14-6
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ end
242242
@testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin
243243
rng = StableRNG(12345)
244244

245+
skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : []
246+
245247
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
246248
@testset "SkipConnection recombinator" begin
247249
d = Dense(2 => 2)
@@ -255,7 +257,8 @@ end
255257

256258
@test size(layer(x, ps, st)[1]) == (3, 1)
257259
@jet layer(x, ps, st)
258-
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
260+
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
261+
skip_backends)
259262

260263
d = Dense(2 => 2)
261264
display(d)
@@ -268,7 +271,8 @@ end
268271

269272
@test size(layer(x, ps, st)[1]) == (3, 1)
270273
@jet layer(x, ps, st)
271-
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
274+
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
275+
skip_backends)
272276

273277
d = Dense(2 => 3)
274278
display(d)
@@ -281,7 +285,8 @@ end
281285

282286
@test size(layer(x, ps, st)[1]) == (5, 7, 11)
283287
@jet layer(x, ps, st)
284-
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
288+
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
289+
skip_backends)
285290
end
286291

287292
@testset "Two-streams zero sum" begin
@@ -296,7 +301,8 @@ end
296301

297302
@test LuxCore.outputsize(layer, (x, y), rng) == (3,)
298303
@jet layer((x, y), ps, st)
299-
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3)
304+
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3,
305+
skip_backends)
300306
end
301307

302308
@testset "Inner interactions" begin
@@ -307,7 +313,8 @@ end
307313

308314
@test size(layer(x, ps, st)[1]) == (3, 1)
309315
@jet layer(x, ps, st)
310-
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
316+
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
317+
skip_backends)
311318

312319
x = randn(Float32, 2, 1) |> aType
313320
layer = Bilinear(2 => 3)
@@ -316,7 +323,8 @@ end
316323

317324
@test size(layer(x, ps, st)[1]) == (3, 1)
318325
@jet layer(x, ps, st)
319-
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
326+
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
327+
skip_backends)
320328
end
321329
end
322330
end

0 commit comments

Comments
 (0)