Skip to content

Fix hashing and memoization of enodes (VecExpr) #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 27 additions & 52 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,22 +224,6 @@ Returns the canonical e-class id for a given e-class.

@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]

# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
# if !v_isexpr(n)
# v_hash!(n)
# return n
# end
# l = v_arity(n)
# new_n = v_new(l)
# v_set_flag!(new_n, v_flags(n))
# v_set_head!(new_n, v_head(n))
# for i in v_children_range(n)
# @inbounds new_n[i] = find(g, n[i])
# end
# v_hash!(new_n)
# new_n
# end

function canonicalize!(g::EGraph, n::VecExpr)
v_isexpr(n) || @goto ret
for i in (VECEXPR_META_LENGTH + 1):length(n)
Expand All @@ -253,19 +237,16 @@ end

function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, n) ? find(g, g.memo[n]) : 0
id = get(g.memo, n, zero(Id))
iszero(id) ? id : find(g, id)
end


function add_class_by_op(g::EGraph, n, eclass_id)
key = IdKey(v_signature(n))
if haskey(g.classes_by_op, key)
push!(g.classes_by_op[key], eclass_id)
else
g.classes_by_op[key] = [eclass_id]
end
vec = get!(g.classes_by_op, key, Vector{Id}())
push!(vec, eclass_id)
end

"""
Expand All @@ -274,7 +255,8 @@ Inserts an e-node in an [`EGraph`](@ref)
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

haskey(g.memo, n) && return g.memo[n]
id = get(g.memo, n, zero(Id))
iszero(id) || return id

if should_copy
n = copy(n)
Expand All @@ -291,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
Expand Down Expand Up @@ -320,28 +302,22 @@ function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)

n = if isexpr(e)
args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)

h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
n
else # constant enode
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))
# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
id = add!(g, n, false)
return id

add!(g, n, false)
end

"""
Expand Down Expand Up @@ -431,10 +407,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending) || !isempty(g.analysis_pending)
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
node = copy(node)
canonicalize!(g, node)
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -474,9 +450,8 @@ function check_memo(g::EGraph)::Bool
for (id, class) in g.classes
@assert id.val == class.id
for node in class.nodes
if haskey(test_memo, node)
old_id = test_memo[node]
test_memo[node] = id.val
old_id = get!(test_memo, node, id.val)
if old_id != id.val
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/uniquequeue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ function Base.pop!(uq::UniqueQueue{T}) where {T}
v
end

Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
2 changes: 1 addition & 1 deletion src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
Expand Down
Loading