@@ -11,6 +11,7 @@ using ComponentArrays
11
11
Symbolics. @variables sym_ca[1 : length (ca)] = ca
12
12
Symbolics. @variables sym_ps:: typeof (ps) = ps
13
13
Symbolics. @variables sym_x[1 : 5 ] = Float32[1 ,2 ,3 ,4 ,5 ]
14
+ Symbolics. @variables sym_model:: typeof (model) = model
14
15
15
16
out_ref = LuxCore. stateless_apply (model, x, ps)
16
17
@test out_ref isa Vector{Float32}
@@ -41,6 +42,12 @@ using ComponentArrays
41
42
@test length (out) == 6
42
43
out_sub = Symbolics. value .(Symbolics. substitute .(Symbolics. scalarize (out), (Dict (sym_x => x, sym_ca => ca),)))
43
44
@test out_sub == out_ref
45
+
46
+ out = LuxCore. stateless_apply (sym_model, sym_x, sym_ca)
47
+ @test out isa Symbolics. Arr
48
+ @test length (out) == 6
49
+ out_sub = Symbolics. value .(Symbolics. substitute .(Symbolics. scalarize (out), (Dict (sym_model => model, sym_x => x, sym_ca => ca),)))
50
+ @test out_sub == out_ref
44
51
end
45
52
46
53
@testset " Chain" begin
53
60
Symbolics. @variables sym_ca[1 : length (ca)] = ca
54
61
Symbolics. @variables sym_ps:: typeof (ps) = ps
55
62
Symbolics. @variables sym_x[1 : 5 ] = Float32[1 , 2 , 3 , 4 , 5 ]
63
+ Symbolics. @variables sym_model:: typeof (model) = model
56
64
57
65
out_ref = LuxCore. stateless_apply (model, x, ps)
58
66
@test out_ref isa Vector{Float32}
83
91
@test length (out) == 3
84
92
out_sub = Symbolics. value .(Symbolics. substitute .(Symbolics. scalarize (out), (Dict (sym_x => x, sym_ca => ca),)))
85
93
@test out_sub == out_ref
94
+
95
+ out = LuxCore. stateless_apply (sym_model, sym_x, sym_ca)
96
+ @test out isa Symbolics. Arr
97
+ @test length (out) == 3
98
+ out_sub = Symbolics. value .(Symbolics. substitute .(Symbolics. scalarize (out), (Dict (sym_model => model, sym_x => x, sym_ca => ca),)))
99
+ @test out_sub == out_ref
86
100
end
0 commit comments