Skip to content

Commit 4d63265

Browse files
Merge pull request #1508 from SebastianM-C/outputsize
make `outputsize` work with symbolically wrapped lux models
2 parents 0f11ece + 7dfe7d9 commit 4d63265

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

Diff for: ext/SymbolicsLuxExt.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SymbolicsLuxExt
33
using Lux
44
using Symbolics
55
using Lux.LuxCore
6+
using Lux.Random: AbstractRNG, default_rng
67
using Symbolics.SymbolicUtils
78

89
@static if isdefined(Lux.NilSizePropagation, :recursively_nillify)
@@ -11,9 +12,13 @@ using Symbolics.SymbolicUtils
1112
end
1213
end
1314

15+
function LuxCore.outputsize(model::SymbolicUtils.BasicSymbolic{<:LuxCore.AbstractLuxLayer}, x::Symbolics.Arr, rng::AbstractRNG)
16+
LuxCore.outputsize(Symbolics.getdefaultval(model), x, rng)
17+
end
18+
1419
@register_array_symbolic LuxCore.stateless_apply(
1520
model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
16-
size = LuxCore.outputsize(model, Symbolics.wrap(x), LuxCore.Random.default_rng())
21+
size = LuxCore.outputsize(model, Symbolics.wrap(x), default_rng())
1722
eltype = Real
1823
end
1924

Diff for: test/extensions/lux.jl

+14
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using ComponentArrays
1111
Symbolics.@variables sym_ca[1:length(ca)] = ca
1212
Symbolics.@variables sym_ps::typeof(ps) = ps
1313
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]
14+
Symbolics.@variables sym_model::typeof(model) = model
1415

1516
out_ref = LuxCore.stateless_apply(model, x, ps)
1617
@test out_ref isa Vector{Float32}
@@ -41,6 +42,12 @@ using ComponentArrays
4142
@test length(out) == 6
4243
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
4344
@test out_sub == out_ref
45+
46+
out = LuxCore.stateless_apply(sym_model, sym_x, sym_ca)
47+
@test out isa Symbolics.Arr
48+
@test length(out) == 6
49+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),)))
50+
@test out_sub == out_ref
4451
end
4552

4653
@testset "Chain" begin
@@ -53,6 +60,7 @@ end
5360
Symbolics.@variables sym_ca[1:length(ca)] = ca
5461
Symbolics.@variables sym_ps::typeof(ps) = ps
5562
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]
63+
Symbolics.@variables sym_model::typeof(model) = model
5664

5765
out_ref = LuxCore.stateless_apply(model, x, ps)
5866
@test out_ref isa Vector{Float32}
@@ -83,4 +91,10 @@ end
8391
@test length(out) == 3
8492
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
8593
@test out_sub == out_ref
94+
95+
out = LuxCore.stateless_apply(sym_model, sym_x, sym_ca)
96+
@test out isa Symbolics.Arr
97+
@test length(out) == 3
98+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),)))
99+
@test out_sub == out_ref
86100
end

0 commit comments

Comments
 (0)