@@ -24,53 +24,59 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT)
24
24
const ENABLE_HASHCONSING = Ref (true )
25
25
26
26
@compactify show_methods= false begin
27
- @abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
28
- metadata:: Metadata = NO_METADATA
29
- end
30
- mutable struct Sym{T} <: BasicSymbolic{T}
27
+ @abstract mutable struct BasicSymbolicImpl{T} end
28
+ mutable struct Sym{T} <: BasicSymbolicImpl{T}
31
29
name:: Symbol = :OOF
32
30
end
33
- mutable struct Term{T} <: BasicSymbolic {T}
31
+ mutable struct Term{T} <: BasicSymbolicImpl {T}
34
32
f:: Any = identity # base/num if Pow; issorted if Add/Dict
35
33
arguments:: Vector{Any} = EMPTY_ARGS
36
34
hash:: RefValue{UInt} = EMPTY_HASH
37
35
hash2:: RefValue{UInt} = EMPTY_HASH
38
36
end
39
- mutable struct Mul{T} <: BasicSymbolic {T}
37
+ mutable struct Mul{T} <: BasicSymbolicImpl {T}
40
38
coeff:: Any = 0 # exp/den if Pow
41
39
dict:: EMPTY_DICT_T = EMPTY_DICT
42
40
hash:: RefValue{UInt} = EMPTY_HASH
43
41
hash2:: RefValue{UInt} = EMPTY_HASH
44
42
arguments:: Vector{Any} = EMPTY_ARGS
45
43
issorted:: RefValue{Bool} = NOT_SORTED
46
44
end
47
- mutable struct Add{T} <: BasicSymbolic {T}
45
+ mutable struct Add{T} <: BasicSymbolicImpl {T}
48
46
coeff:: Any = 0 # exp/den if Pow
49
47
dict:: EMPTY_DICT_T = EMPTY_DICT
50
48
hash:: RefValue{UInt} = EMPTY_HASH
51
49
hash2:: RefValue{UInt} = EMPTY_HASH
52
50
arguments:: Vector{Any} = EMPTY_ARGS
53
51
issorted:: RefValue{Bool} = NOT_SORTED
54
52
end
55
- mutable struct Div{T} <: BasicSymbolic {T}
53
+ mutable struct Div{T} <: BasicSymbolicImpl {T}
56
54
num:: Any = 1
57
55
den:: Any = 1
58
56
simplified:: Bool = false
59
57
arguments:: Vector{Any} = EMPTY_ARGS
60
58
end
61
- mutable struct Pow{T} <: BasicSymbolic {T}
59
+ mutable struct Pow{T} <: BasicSymbolicImpl {T}
62
60
base:: Any = 1
63
61
exp:: Any = 1
64
62
arguments:: Vector{Any} = EMPTY_ARGS
65
63
end
66
64
end
67
65
66
+ @kwdef struct BasicSymbolic{T} <: Symbolic{T}
67
+ impl:: BasicSymbolicImpl{T}
68
+ metadata:: Metadata = NO_METADATA
69
+ end
70
+
68
71
function SymbolicIndexingInterface. symbolic_type (:: Type{<:BasicSymbolic} )
69
72
ScalarSymbolic ()
70
73
end
71
74
72
75
function exprtype (x:: BasicSymbolic )
73
- @compactified x:: BasicSymbolic begin
76
+ exprtype (x. impl)
77
+ end
78
+ function exprtype (impl:: BasicSymbolicImpl )
79
+ @compactified impl:: BasicSymbolicImpl begin
74
80
Term => TERM
75
81
Add => ADD
76
82
Mul => MUL
@@ -81,7 +87,15 @@ function exprtype(x::BasicSymbolic)
81
87
end
82
88
end
83
89
84
- const wvd = WeakValueDict {UInt, BasicSymbolic} ()
90
+ function Base. getproperty (x:: BasicSymbolic , sym:: Symbol )
91
+ if sym === :metadata || sym === :impl
92
+ return getfield (x, sym)
93
+ else
94
+ return getproperty (x. impl, sym)
95
+ end
96
+ end
97
+
98
+ const wvd = WeakValueDict {UInt, BasicSymbolicImpl} ()
85
99
86
100
# Same but different error messages
87
101
@noinline error_on_type () = error (" Internal error: unreachable reached!" )
@@ -99,7 +113,7 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
99
113
nt = getproperties (obj)
100
114
nt_new = merge (nt, patch)
101
115
# Call outer constructor because hash consing cannot be applied in inner constructor
102
- @compactified obj:: BasicSymbolic begin
116
+ @compactified obj. impl :: BasicSymbolicImpl begin
103
117
Sym => Sym {T} (nt_new. name; nt_new... )
104
118
Term => Term {T} (nt_new. f, nt_new. arguments; nt_new... , hash = RefValue (UInt (0 )), hash2 = RefValue (UInt (0 )))
105
119
Add => Add (T, nt_new. coeff, nt_new. dict; nt_new... , hash = RefValue (UInt (0 )), hash2 = RefValue (UInt (0 )))
@@ -128,9 +142,12 @@ symtype(x) = typeof(x)
128
142
@inline symtype (:: Type{<:Symbolic{T}} ) where T = T
129
143
130
144
# We're returning a function pointer
131
- @inline function operation (x:: BasicSymbolic )
132
- @compactified x:: BasicSymbolic begin
133
- Term => x. f
145
+ function operation (x:: BasicSymbolic )
146
+ operation (x. impl)
147
+ end
148
+ @inline function operation (impl:: BasicSymbolicImpl )
149
+ @compactified impl:: BasicSymbolicImpl begin
150
+ Term => impl. f
134
151
Add => (+ )
135
152
Mul => (* )
136
153
Div => (/ )
144
161
145
162
function TermInterface. sorted_arguments (x:: BasicSymbolic )
146
163
args = arguments (x)
147
- @compactified x:: BasicSymbolic begin
164
+ @compactified x. impl :: BasicSymbolicImpl begin
148
165
Add => @goto ADD
149
166
Mul => @goto MUL
150
167
_ => return args
169
186
TermInterface. children (x:: BasicSymbolic ) = arguments (x)
170
187
TermInterface. sorted_children (x:: BasicSymbolic ) = sorted_arguments (x)
171
188
function TermInterface. arguments (x:: BasicSymbolic )
172
- @compactified x:: BasicSymbolic begin
189
+ arguments (x. impl)
190
+ end
191
+ function TermInterface. arguments (x:: BasicSymbolicImpl )
192
+ @compactified x:: BasicSymbolicImpl begin
173
193
Term => return x. arguments
174
194
Add => @goto ADDMUL
175
195
Mul => @goto ADDMUL
219
239
isexpr (s:: BasicSymbolic ) = ! issym (s)
220
240
iscall (s:: BasicSymbolic ) = isexpr (s)
221
241
222
- @inline isa_SymType (T:: Val{S} , x) where {S} = x isa BasicSymbolic ? Unityper. isa_type_fun (Val (SymbolicUtils. BasicSymbolic), T, x) : false
242
+ @inline function isa_SymType (T:: Val{S} , x) where {S}
243
+ if x isa BasicSymbolic
244
+ Unityper. isa_type_fun (Val (SymbolicUtils. BasicSymbolicImpl), T, x. impl)
245
+ elseif x isa BasicSymbolicImpl
246
+ Unityper. isa_type_fun (Val (SymbolicUtils. BasicSymbolicImpl), T, x)
247
+ else
248
+ false
249
+ end
250
+ end
223
251
224
252
"""
225
253
issym(x)
395
423
Base. one ( s:: Symbolic ) = one ( symtype (s))
396
424
Base. zero (s:: Symbolic ) = zero (symtype (s))
397
425
398
- Base. nameof (s:: BasicSymbolic ) = issym (s) ? s. name : error (" None Sym BasicSymbolic doesn't have a name" )
426
+ Base. nameof (s:: BasicSymbolic ) = nameof (s. impl)
427
+ function Base. nameof (s:: BasicSymbolicImpl )
428
+ if issym (s)
429
+ s. name
430
+ else
431
+ error (" None Sym BasicSymbolic doesn't have a name" )
432
+ end
433
+ end
399
434
400
435
# # This is much faster than hash of an array of Any
401
436
hashvec (xs, z) = foldr (hash, xs, init= z)
@@ -458,7 +493,8 @@ function hash2(n::T, salt::UInt) where {T <: Number}
458
493
hash (T, hash (n, salt))
459
494
end
460
495
hash2 (s:: BasicSymbolic ) = hash2 (s, zero (UInt))
461
- function hash2 (s:: BasicSymbolic{T} , salt:: UInt ):: UInt where {T}
496
+ hash2 (s:: BasicSymbolicImpl ) = hash2 (s, zero (UInt))
497
+ function hash2 (s:: BasicSymbolicImpl{T} , salt:: UInt ):: UInt where {T}
462
498
E = exprtype (s)
463
499
h:: UInt = 0
464
500
if E === SYM
@@ -520,7 +556,7 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h
520
556
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
521
557
original behavior of those functions.
522
558
"""
523
- function BasicSymbolic (s:: BasicSymbolic ):: BasicSymbolic
559
+ function BasicSymbolicImpl (s:: BasicSymbolicImpl ):: BasicSymbolicImpl
524
560
if ! ENABLE_HASHCONSING[]
525
561
return s
526
562
end
@@ -533,18 +569,20 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
533
569
end
534
570
end
535
571
536
- function Sym {T} (name:: Symbol ; kw... ) where {T}
572
+ function Sym {T} (name:: Symbol ; metadata = NO_METADATA, kw... ) where {T}
537
573
s = Sym {T} (; name, kw... )
538
- BasicSymbolic (s)
574
+ bsi = BasicSymbolicImpl (s)
575
+ BasicSymbolic (bsi, metadata)
539
576
end
540
577
541
- function Term {T} (f, args; kw... ) where T
578
+ function Term {T} (f, args; metadata = NO_METADATA, kw... ) where T
542
579
if eltype (args) != = Any
543
580
args = convert (Vector{Any}, args)
544
581
end
545
582
546
583
s = Term {T} (;f= f, arguments= args, hash= Ref (UInt (0 )), hash2= Ref (UInt (0 )), kw... )
547
- BasicSymbolic (s)
584
+ bsi = BasicSymbolicImpl (s)
585
+ BasicSymbolic (bsi, metadata)
548
586
end
549
587
550
588
function Term (f, args; metadata= NO_METADATA)
0 commit comments