Skip to content

Commit ad65754

Browse files
committed
Create BasicSymbolicImpl struct to separate metadata from hash consing
1 parent b1f111b commit ad65754

File tree

1 file changed

+63
-25
lines changed

1 file changed

+63
-25
lines changed

src/types.jl

+63-25
Original file line numberDiff line numberDiff line change
@@ -24,53 +24,59 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT)
2424
const ENABLE_HASHCONSING = Ref(true)
2525

2626
@compactify show_methods=false begin
27-
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
28-
metadata::Metadata = NO_METADATA
29-
end
30-
mutable struct Sym{T} <: BasicSymbolic{T}
27+
@abstract mutable struct BasicSymbolicImpl{T} end
28+
mutable struct Sym{T} <: BasicSymbolicImpl{T}
3129
name::Symbol = :OOF
3230
end
33-
mutable struct Term{T} <: BasicSymbolic{T}
31+
mutable struct Term{T} <: BasicSymbolicImpl{T}
3432
f::Any = identity # base/num if Pow; issorted if Add/Dict
3533
arguments::Vector{Any} = EMPTY_ARGS
3634
hash::RefValue{UInt} = EMPTY_HASH
3735
hash2::RefValue{UInt} = EMPTY_HASH
3836
end
39-
mutable struct Mul{T} <: BasicSymbolic{T}
37+
mutable struct Mul{T} <: BasicSymbolicImpl{T}
4038
coeff::Any = 0 # exp/den if Pow
4139
dict::EMPTY_DICT_T = EMPTY_DICT
4240
hash::RefValue{UInt} = EMPTY_HASH
4341
hash2::RefValue{UInt} = EMPTY_HASH
4442
arguments::Vector{Any} = EMPTY_ARGS
4543
issorted::RefValue{Bool} = NOT_SORTED
4644
end
47-
mutable struct Add{T} <: BasicSymbolic{T}
45+
mutable struct Add{T} <: BasicSymbolicImpl{T}
4846
coeff::Any = 0 # exp/den if Pow
4947
dict::EMPTY_DICT_T = EMPTY_DICT
5048
hash::RefValue{UInt} = EMPTY_HASH
5149
hash2::RefValue{UInt} = EMPTY_HASH
5250
arguments::Vector{Any} = EMPTY_ARGS
5351
issorted::RefValue{Bool} = NOT_SORTED
5452
end
55-
mutable struct Div{T} <: BasicSymbolic{T}
53+
mutable struct Div{T} <: BasicSymbolicImpl{T}
5654
num::Any = 1
5755
den::Any = 1
5856
simplified::Bool = false
5957
arguments::Vector{Any} = EMPTY_ARGS
6058
end
61-
mutable struct Pow{T} <: BasicSymbolic{T}
59+
mutable struct Pow{T} <: BasicSymbolicImpl{T}
6260
base::Any = 1
6361
exp::Any = 1
6462
arguments::Vector{Any} = EMPTY_ARGS
6563
end
6664
end
6765

66+
@kwdef struct BasicSymbolic{T} <: Symbolic{T}
67+
impl::BasicSymbolicImpl{T}
68+
metadata::Metadata = NO_METADATA
69+
end
70+
6871
function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
6972
ScalarSymbolic()
7073
end
7174

7275
function exprtype(x::BasicSymbolic)
73-
@compactified x::BasicSymbolic begin
76+
exprtype(x.impl)
77+
end
78+
function exprtype(impl::BasicSymbolicImpl)
79+
@compactified impl::BasicSymbolicImpl begin
7480
Term => TERM
7581
Add => ADD
7682
Mul => MUL
@@ -81,7 +87,15 @@ function exprtype(x::BasicSymbolic)
8187
end
8288
end
8389

84-
const wvd = WeakValueDict{UInt, BasicSymbolic}()
90+
function Base.getproperty(x::BasicSymbolic, sym::Symbol)
91+
if sym === :metadata || sym === :impl
92+
return getfield(x, sym)
93+
else
94+
return getproperty(x.impl, sym)
95+
end
96+
end
97+
98+
const wvd = WeakValueDict{UInt, BasicSymbolicImpl}()
8599

86100
# Same but different error messages
87101
@noinline error_on_type() = error("Internal error: unreachable reached!")
@@ -99,7 +113,7 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
99113
nt = getproperties(obj)
100114
nt_new = merge(nt, patch)
101115
# Call outer constructor because hash consing cannot be applied in inner constructor
102-
@compactified obj::BasicSymbolic begin
116+
@compactified obj.impl::BasicSymbolicImpl begin
103117
Sym => Sym{T}(nt_new.name; nt_new...)
104118
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
105119
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
@@ -128,9 +142,12 @@ symtype(x) = typeof(x)
128142
@inline symtype(::Type{<:Symbolic{T}}) where T = T
129143

130144
# We're returning a function pointer
131-
@inline function operation(x::BasicSymbolic)
132-
@compactified x::BasicSymbolic begin
133-
Term => x.f
145+
function operation(x::BasicSymbolic)
146+
operation(x.impl)
147+
end
148+
@inline function operation(impl::BasicSymbolicImpl)
149+
@compactified impl::BasicSymbolicImpl begin
150+
Term => impl.f
134151
Add => (+)
135152
Mul => (*)
136153
Div => (/)
@@ -144,7 +161,7 @@ end
144161

145162
function TermInterface.sorted_arguments(x::BasicSymbolic)
146163
args = arguments(x)
147-
@compactified x::BasicSymbolic begin
164+
@compactified x.impl::BasicSymbolicImpl begin
148165
Add => @goto ADD
149166
Mul => @goto MUL
150167
_ => return args
@@ -169,7 +186,10 @@ end
169186
TermInterface.children(x::BasicSymbolic) = arguments(x)
170187
TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x)
171188
function TermInterface.arguments(x::BasicSymbolic)
172-
@compactified x::BasicSymbolic begin
189+
arguments(x.impl)
190+
end
191+
function TermInterface.arguments(x::BasicSymbolicImpl)
192+
@compactified x::BasicSymbolicImpl begin
173193
Term => return x.arguments
174194
Add => @goto ADDMUL
175195
Mul => @goto ADDMUL
@@ -219,7 +239,15 @@ end
219239
isexpr(s::BasicSymbolic) = !issym(s)
220240
iscall(s::BasicSymbolic) = isexpr(s)
221241

222-
@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false
242+
@inline function isa_SymType(T::Val{S}, x) where {S}
243+
if x isa BasicSymbolic
244+
Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.impl)
245+
elseif x isa BasicSymbolicImpl
246+
Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x)
247+
else
248+
false
249+
end
250+
end
223251

224252
"""
225253
issym(x)
@@ -395,7 +423,14 @@ end
395423
Base.one( s::Symbolic) = one( symtype(s))
396424
Base.zero(s::Symbolic) = zero(symtype(s))
397425

398-
Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name")
426+
Base.nameof(s::BasicSymbolic) = nameof(s.impl)
427+
function Base.nameof(s::BasicSymbolicImpl)
428+
if issym(s)
429+
s.name
430+
else
431+
error("None Sym BasicSymbolic doesn't have a name")
432+
end
433+
end
399434

400435
## This is much faster than hash of an array of Any
401436
hashvec(xs, z) = foldr(hash, xs, init=z)
@@ -458,7 +493,8 @@ function hash2(n::T, salt::UInt) where {T <: Number}
458493
hash(T, hash(n, salt))
459494
end
460495
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
461-
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
496+
hash2(s::BasicSymbolicImpl) = hash2(s, zero(UInt))
497+
function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T}
462498
E = exprtype(s)
463499
h::UInt = 0
464500
if E === SYM
@@ -520,7 +556,7 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h
520556
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
521557
original behavior of those functions.
522558
"""
523-
function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
559+
function BasicSymbolicImpl(s::BasicSymbolicImpl)::BasicSymbolicImpl
524560
if !ENABLE_HASHCONSING[]
525561
return s
526562
end
@@ -533,18 +569,20 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
533569
end
534570
end
535571

536-
function Sym{T}(name::Symbol; kw...) where {T}
572+
function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T}
537573
s = Sym{T}(; name, kw...)
538-
BasicSymbolic(s)
574+
bsi = BasicSymbolicImpl(s)
575+
BasicSymbolic(bsi, metadata)
539576
end
540577

541-
function Term{T}(f, args; kw...) where T
578+
function Term{T}(f, args; metadata = NO_METADATA, kw...) where T
542579
if eltype(args) !== Any
543580
args = convert(Vector{Any}, args)
544581
end
545582

546583
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
547-
BasicSymbolic(s)
584+
bsi = BasicSymbolicImpl(s)
585+
BasicSymbolic(bsi, metadata)
548586
end
549587

550588
function Term(f, args; metadata=NO_METADATA)

0 commit comments

Comments
 (0)