Skip to content

Commit 233729e

Browse files
feat: use a global atomic UInt64 to give each BasicSymbolic a unique ID
1 parent 675f8a2 commit 233729e

File tree

3 files changed

+49
-84
lines changed

3 files changed

+49
-84
lines changed

src/cache.jl

+19-28
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,23 @@ struct CacheSentinel end
88
"""
99
$(TYPEDEF)
1010
11-
Struct wrapping the `objectid` of a `BasicSymbolic`, since arguments annotated
11+
Struct wrapping the `id` of a `BasicSymbolic`, since arguments annotated
1212
`::Union{BasicSymbolic, UInt}` would not be able to differentiate between looking
1313
up a symbolic or a `UInt`.
1414
"""
1515
struct SymbolicKey
16-
id::UInt
16+
id::UInt64
1717
end
1818

19+
"""
20+
$(TYPEDSIGNATURES)
21+
22+
The key stored in the cache for a particular value. Returns a `SymbolicKey` for
23+
`BasicSymbolic` and is the identity function otherwise.
24+
"""
25+
# can't dispatch because `BasicSymbolic` isn't defined here
26+
get_cache_key(x) = x isa BasicSymbolic ? SymbolicKey(x.id[]) : x
27+
1928
"""
2029
associated_cache(fn)
2130
@@ -233,10 +242,6 @@ macro cache(args...)
233242
cache_value_name = :val
234243
# The condition for a cache hit
235244
cache_hit_condition = :(!($cache_value_name isa $CacheSentinel))
236-
# Type of additional data stored with cached result. Used to compare
237-
# equality of `BasicSymbolic` arguments, since `objectid` is a hash.
238-
cache_additional_types = []
239-
cache_additional_values = []
240245

241246
for arg in fn.args
242247
# handle arguments with defaults
@@ -245,24 +250,18 @@ macro cache(args...)
245250
end
246251
if !Meta.isexpr(arg, :(::))
247252
# if the type is `Any`, branch on it being a `BasicSymbolic`
248-
push!(keyexprs, :($arg isa BasicSymbolic ? $SymbolicKey(objectid($arg)) : $arg))
253+
push!(keyexprs, :($get_cache_key($arg)))
249254
push!(argexprs, arg)
250255
push!(keytypes, Any)
251-
push!(cache_additional_types, Any)
252-
push!(cache_additional_values, arg)
253-
cache_hit_condition = :($cache_hit_condition && (!($arg isa BasicSymbolic) || $arg === $cache_value_name[$(length(cache_additional_values))]))
254256
continue
255257
end
256258
argname, Texpr = arg.args
257259
push!(argexprs, argname)
258260

259261
if Texpr == :Any
260262
# if the type is `Any`, branch on it being a `BasicSymbolic`
261-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
263+
push!(keyexprs, :($get_cache_key($argname)))
262264
push!(keytypes, Any)
263-
push!(cache_additional_types, Any)
264-
push!(cache_additional_values, argname)
265-
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
266265
continue
267266
end
268267

@@ -274,10 +273,7 @@ macro cache(args...)
274273
maybe_basicsymbolic = any(x -> x <: BasicSymbolic, Ts)
275274
push!(keytypes, Union{keyTs...})
276275
if maybe_basicsymbolic
277-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
278-
push!(cache_additional_types, Texpr)
279-
push!(cache_additional_values, argname)
280-
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
276+
push!(keyexprs, :($get_cache_key($argname)))
281277
else
282278
push!(keyexprs, argname)
283279
end
@@ -288,10 +284,7 @@ macro cache(args...)
288284
T = Base.eval(__module__, Texpr)
289285
if T <: BasicSymbolic
290286
push!(keytypes, SymbolicKey)
291-
push!(keyexprs, :($SymbolicKey(objectid($argname))))
292-
push!(cache_additional_types, T)
293-
push!(cache_additional_values, argname)
294-
cache_hit_condition = :($cache_hit_condition && $argname === $cache_value_name[$(length(cache_additional_values))])
287+
push!(keyexprs, :($get_cache_key($argname)))
295288
else
296289
push!(keytypes, T)
297290
push!(keyexprs, argname)
@@ -312,9 +305,7 @@ macro cache(args...)
312305
# construct an expression for the type of the cache keys
313306
keyT = Expr(:curly, Tuple)
314307
append!(keyT.args, keytypes)
315-
valT = Expr(:curly, Tuple)
316-
append!(valT.args, cache_additional_types)
317-
push!(valT.args, rettype)
308+
valT = rettype
318309
# the type of the cache
319310
cacheT = :(Dict{$keyT, $valT})
320311
# type of the `TaskLocalValue`
@@ -363,7 +354,7 @@ macro cache(args...)
363354
if $cache_hit_condition
364355
# cache hit
365356
cachestats.hits += 1
366-
return $cache_value_name[end]
357+
return $cache_value_name
367358
end
368359
# cache miss
369360
cachestats.misses += 1
@@ -374,8 +365,8 @@ macro cache(args...)
374365
$(filter!)($cachename, cachedict)
375366
end
376367
# add to cache
377-
cachedict[key] = ($(cache_additional_values...), val)
378-
return val
368+
cachedict[key] = $cache_value_name
369+
return $cache_value_name
379370
end
380371

381372
# if we're not doing caching

src/types.jl

+21-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ const ENABLE_HASHCONSING = Ref(true)
2525
@compactify show_methods=false begin
2626
@abstract struct BasicSymbolic{T} <: Symbolic{T}
2727
metadata::Metadata = NO_METADATA
28+
id::RefValue{UInt64} = Ref{UInt64}(0)
2829
end
2930
struct Sym{T} <: BasicSymbolic{T}
3031
name::Symbol = :OOF
@@ -114,11 +115,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
114115
# Call outer constructor because hash consing cannot be applied in inner constructor
115116
@compactified obj::BasicSymbolic begin
116117
Sym => Sym{T}(nt_new.name; nt_new...)
117-
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
118-
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
119-
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
120-
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
121-
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
118+
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
119+
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
120+
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
121+
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
122+
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
122123
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
123124
end
124125
end
@@ -262,6 +263,7 @@ end
262263

263264
function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
264265
a === b && return true
266+
a.id == b.id && a.id != 0 && return true
265267

266268
E = exprtype(a)
267269
E === exprtype(b) || return false
@@ -305,6 +307,7 @@ function.
305307
"""
306308
function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S}
307309
a === b && return true
310+
a.id == b.id && a.id != 0 && return true
308311

309312
E = exprtype(a)
310313
E === exprtype(b) || return false
@@ -523,6 +526,12 @@ end
523526
### Constructors
524527
###
525528

529+
mutable struct AtomicIDCounter
530+
@atomic x::UInt64
531+
end
532+
533+
const ID_COUNTER = AtomicIDCounter(0)
534+
526535
"""
527536
$(TYPEDSIGNATURES)
528537
@@ -552,6 +561,7 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
552561
hcw = HashConsingWrapper(s)
553562
k = getkey(cache, hcw, nothing)
554563
if isnothing(k)
564+
hcw.bs.id[] = @atomic ID_COUNTER.x += 1
555565
cache[hcw] = nothing
556566
return s
557567
else
@@ -560,7 +570,7 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
560570
end
561571

562572
function Sym{T}(name::Symbol; kw...) where {T}
563-
s = Sym{T}(; name, kw...)
573+
s = Sym{T}(; name, kw..., id = Ref{UInt}(0))
564574
BasicSymbolic(s)
565575
end
566576

@@ -576,7 +586,7 @@ function Term{T}(f, args; kw...) where T
576586
end
577587
unwrap_arr!(args)
578588

579-
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
589+
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw..., id = Ref{UInt64}(0))
580590
BasicSymbolic(s)
581591
end
582592

@@ -606,7 +616,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
606616
end
607617
end
608618

609-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
619+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw..., id = Ref{UInt64}(0))
610620
BasicSymbolic(s)
611621
end
612622

@@ -624,7 +634,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
624634
else
625635
coeff = a
626636
dict = b
627-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
637+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw..., id = Ref{UInt64}(0))
628638
BasicSymbolic(s)
629639
end
630640
end
@@ -692,7 +702,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
692702
end
693703
end
694704

695-
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
705+
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata, id = Ref{UInt64}(0))
696706
BasicSymbolic(s)
697707
end
698708

@@ -712,7 +722,7 @@ function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
712722
b = unwrap(b)
713723
_iszero(b) && return 1
714724
_isone(b) && return a
715-
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
725+
s = Pow{T}(; base=a, exp=b, arguments=[], metadata, id = Ref{UInt64}(0))
716726
BasicSymbolic(s)
717727
end
718728

test/cache_macro.jl

+9-45
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SymbolicUtils
22
using SymbolicUtils: BasicSymbolic, @cache, associated_cache, set_limit!, get_limit,
3-
clear_cache!, SymbolicKey, metadata, maketerm
3+
clear_cache!, SymbolicKey, metadata, maketerm, get_cache_key
44
using OhMyThreads: tmap
55
using Random
66

@@ -14,9 +14,9 @@ end
1414
@test isequal(val, 2x + 1)
1515
cachestruct = associated_cache(f1)
1616
cache, stats = cachestruct.tlv[]
17-
@test cache isa Dict{Tuple{SymbolicKey}, Tuple{BasicSymbolic, BasicSymbolic}}
17+
@test cache isa Dict{Tuple{SymbolicKey}, BasicSymbolic}
1818
@test length(cache) == 1
19-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
19+
@test cache[(get_cache_key(x),)] === val
2020
@test stats.hits == 0
2121
@test stats.misses == 1
2222
f1(x)
@@ -76,20 +76,20 @@ end
7676
@test isequal(val, 2x + 1)
7777
cachestruct = associated_cache(f2)
7878
cache, stats = cachestruct.tlv[]
79-
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, NTuple{2, Union{BasicSymbolic, UInt}}}
79+
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
8080
@test length(cache) == 1
81-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
81+
@test cache[(get_cache_key(x),)] === val
8282
@test stats.hits == 0
8383
@test stats.misses == 1
8484
f2(x)
8585
@test stats.hits == 1
8686
@test stats.misses == 1
8787

88-
y = objectid(x)
88+
y = get_cache_key(x).id
8989
val = f2(y)
9090
@test val == 2y + 1
9191
@test length(cache) == 2
92-
@test cache[(y,)][end] == val
92+
@test cache[(y,)] == val
9393
@test stats.misses == 2
9494

9595
clear_cache!(f2)
@@ -111,9 +111,9 @@ end
111111
@test isequal(val, 2x + 1)
112112
cachestruct = associated_cache(fn)
113113
cache, stats = cachestruct.tlv[]
114-
@test cache isa Dict{Tuple{Any}, Tuple{Any, Union{BasicSymbolic, Int}}}
114+
@test cache isa Dict{Tuple{Any}, Union{BasicSymbolic, Int}}
115115
@test length(cache) == 1
116-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
116+
@test cache[(get_cache_key(x),)] === val
117117
@test stats.hits == 0
118118
@test stats.misses == 1
119119
fn(x)
@@ -160,39 +160,3 @@ end
160160
truevals = map(f4, exprs)
161161
@test isequal(result, truevals)
162162
end
163-
164-
@cache function f5(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
165-
return x + y + z
166-
end
167-
168-
# temporary definition to induce objectid collisions
169-
Base.objectid(x::BasicSymbolic) = 0x42
170-
171-
@testset "`objectid` collision handling" begin
172-
@syms x y z
173-
@test objectid(x) == objectid(y) == objectid(z) == 0x42
174-
cachestruct = associated_cache(f5)
175-
cache, stats = cachestruct.tlv[]
176-
val = f5(x, 1, 2)
177-
@test isequal(val, x + 3)
178-
@test length(cache) == 1
179-
@test stats.misses == 1
180-
val2 = f5(y, 1, 2)
181-
@test isequal(val2, y + 3)
182-
@test length(cache) == 1
183-
@test stats.misses == 2
184-
185-
clear_cache!(f5)
186-
val = f5(x, y, z)
187-
@test isequal(val, x + y + z)
188-
@test length(cache) == 1
189-
@test stats.misses == 1
190-
val2 = f5(y, 2z, x)
191-
@test isequal(val2, x + y + 2z)
192-
@test length(cache) == 1
193-
@test stats.misses == 2
194-
end
195-
196-
Base.delete_method(only(methods(objectid, @__MODULE__)))
197-
@syms x
198-
@test objectid(x) != 0x42

0 commit comments

Comments
 (0)