Skip to content

Graph evaluator caching #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,24 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"

[extensions]
DynamicExpressionsBumperExt = "Bumper"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
DynamicExpressionsOptimExt = "Optim"
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"

[compat]
Bumper = "0.6"
ChainRulesCore = "1"
Compat = "3.37, 4"
DispatchDoctor = "0.4"
Interfaces = "0.3"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Optim = "0.19, 1"
PackageExtensionCompat = "1"
Expand All @@ -47,7 +46,6 @@ julia = "1.6"

[extras]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
66 changes: 64 additions & 2 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
module DynamicExpressionsLoopVectorizationExt

using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions

using LoopVectorization: @turbo, vmapnt
using DynamicExpressions: AbstractExpressionNode, GraphNode, OperatorEnum
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
import DynamicExpressions.EvaluateModule:
Expand All @@ -14,6 +16,7 @@ import DynamicExpressions.EvaluateModule:
deg2_r0_eval
import DynamicExpressions.ExtensionInterfaceModule:
_is_loopvectorization_loaded, bumper_kern1!, bumper_kern2!
import DynamicExpressions.ValueInterfaceModule: is_valid, is_valid_array

_is_loopvectorization_loaded(::Int) = true

Expand Down Expand Up @@ -230,4 +233,63 @@ function bumper_kern2!(
return cumulator1
end



# graph eval

# LoopVectorization-accelerated graph evaluator.
#
# Overwrites the `Val{true}` error stub in `DynamicExpressions.EvaluateModule`
# once this extension is loaded. Walks the (possibly node-sharing) graph in
# topological order, folding constant subexpressions into `node.val` and
# storing per-node row vectors in `node.cache`.
#
# Returns a `ResultOk` whose `ok` flag is `false` as soon as any intermediate
# scalar or array value is non-finite (early exit), matching the tree-eval
# convention elsewhere in the package.
#
# NOTE(review): unlike the `Val{false}` fallback, this method does not consult
# or reset `node.modified`, so it always recomputes every node — confirm this
# is intentional (the upstream comment suggests caching may not pay off here).
function DynamicExpressions.EvaluateModule._eval_graph_array(
    root::GraphNode{T},
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    ::Val{true},
) where {T}
    # Per the upstream note: `vmap` is faster for small `cX`, while `vmapnt`
    # (non-temporal stores) wins for larger `cX` sizes.
    order = topological_sort(root)
    for node in order
        if node.degree == 0 && !node.constant
            # Feature leaf: alias the matrix row without copying.
            node.cache = view(cX, node.feature, :)
        elseif node.degree == 1
            if node.l.constant
                # Constant folding: unary op of a constant is a constant.
                node.constant = true
                node.val = operators.unaops[node.op](node.l.val)
                if !is_valid(node.val)
                    return ResultOk(Vector{T}(undef, size(cX, 2)), false)
                end
            else
                node.constant = false
                node.cache = vmapnt(operators.unaops[node.op], node.l.cache)
                if !is_valid_array(node.cache)
                    return ResultOk(node.cache, false)
                end
            end
        elseif node.degree == 2
            if node.l.constant
                if node.r.constant
                    node.constant = true
                    node.val = operators.binops[node.op](node.l.val, node.r.val)
                    if !is_valid(node.val)
                        return ResultOk(Vector{T}(undef, size(cX, 2)), false)
                    end
                else
                    # Left operand is a scalar constant: partially apply it.
                    node.constant = false
                    node.cache = vmapnt(
                        Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache
                    )
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                end
            else
                if node.r.constant
                    # Right operand is a scalar constant: partially apply it.
                    node.constant = false
                    node.cache = vmapnt(
                        Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache
                    )
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                else
                    node.constant = false
                    node.cache = vmapnt(
                        operators.binops[node.op], node.l.cache, node.r.cache
                    )
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                end
            end
        end
    end
    if root.constant
        # Whole graph folded to a scalar: broadcast it to one value per column.
        return ResultOk(fill(root.val, size(cX, 2)), true)
    else
        return ResultOk(root.cache, true)
    end
end

end
9 changes: 6 additions & 3 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ import .ValueInterfaceModule:
set_node!,
tree_mapreduce,
filter_map,
filter_map!
filter_map!,
topological_sort,
randomised_topological_sort
import .NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
max_degree,
leaf_copy,
branch_copy,
leaf_hash,
Expand All @@ -66,8 +69,7 @@ import .NodeModule:
count_scalar_constants,
get_scalar_constants,
set_scalar_constants!
@reexport import .StringsModule: string_tree, print_tree
import .StringsModule: get_op_name
@reexport import .StringsModule: string_tree, print_tree, get_op_name
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
Expand Down Expand Up @@ -104,6 +106,7 @@ end
import .InterfacesModule:
ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except


# Module initialization hook. `@require_extensions` presumably comes from
# PackageExtensionCompat (listed in [compat]) and loads the ext/ modules on
# Julia versions without native package extensions — TODO confirm against
# the package's import list.
function __init__()
    @require_extensions
end
Expand Down
157 changes: 156 additions & 1 deletion src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module EvaluateModule

using DispatchDoctor: @stable, @unstable

import ..NodeModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort
import ..StringsModule: string_tree
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: fill_similar, counttuple, ResultOk
Expand Down Expand Up @@ -854,4 +854,159 @@ end
end
end

# Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly

# overwritten in ext/DynamicExpressionsLoopVectorizationExt.jl
# Placeholder for the turbo (`Val{true}`) graph evaluator. The real
# implementation lives in ext/DynamicExpressionsLoopVectorizationExt.jl and
# replaces this method once LoopVectorization is loaded; calling it without
# the extension is always an error.
function _eval_graph_array(
    ::GraphNode{T}, ::AbstractMatrix{T}, ::OperatorEnum, ::Val{true}
) where {T}
    return error("DynamicExpressionsLoopVectorizationExt did not overwrite _eval_graph_array")
end

# Fallback (non-LoopVectorization) graph evaluator.
#
# Walks the graph in topological order, skipping the longest prefix of nodes
# whose `modified` flag is clear (their cached `val`/`cache` is reused); once
# one modified node is encountered, every subsequent node is recomputed.
# Constant subexpressions are folded into `node.val`; array results are
# stored in `node.cache`.
#
# Returns a `ResultOk`; `ok` is `false` as soon as any intermediate value is
# non-finite (early exit).
function _eval_graph_array(
    root::GraphNode{T},
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    ::Val{false},
) where {T}
    order = topological_sort(root)
    skip = true
    for node in order
        # `skip` stays true only while every node seen so far is unmodified.
        skip &= !node.modified
        if skip
            continue
        end
        node.modified = false
        if node.degree == 0 && !node.constant
            # Feature leaf: alias the matrix row without copying.
            node.cache = view(cX, node.feature, :)
        elseif node.degree == 1
            if node.l.constant
                # Constant folding: unary op of a constant is a constant.
                node.constant = true
                node.val = operators.unaops[node.op](node.l.val)
                if !is_valid(node.val)
                    return ResultOk(Vector{T}(undef, size(cX, 2)), false)
                end
            else
                node.constant = false
                node.cache = map(operators.unaops[node.op], node.l.cache)
                if !is_valid_array(node.cache)
                    return ResultOk(node.cache, false)
                end
            end
        elseif node.degree == 2
            if node.l.constant
                if node.r.constant
                    node.constant = true
                    node.val = operators.binops[node.op](node.l.val, node.r.val)
                    if !is_valid(node.val)
                        return ResultOk(Vector{T}(undef, size(cX, 2)), false)
                    end
                else
                    node.constant = false
                    node.cache = map(
                        Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache
                    )
                    # BUGFIX: was `is_valid_array(cache[node])` — `cache` is
                    # undefined here; the result lives in `node.cache`.
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                end
            else
                if node.r.constant
                    node.constant = false
                    node.cache = map(
                        Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache
                    )
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                else
                    node.constant = false
                    node.cache = map(operators.binops[node.op], node.l.cache, node.r.cache)
                    if !is_valid_array(node.cache)
                        return ResultOk(node.cache, false)
                    end
                end
            end
        end
    end
    if root.constant
        # Whole graph folded to a scalar: broadcast it to one value per column.
        return ResultOk(fill(root.val, size(cX, 2)), true)
    else
        return ResultOk(root.cache, true)
    end
end

"""
    eval_tree_array(root::GraphNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, eval_options=nothing)

Evaluate a (possibly node-sharing) expression graph over the columns of `cX`,
dispatching to the LoopVectorization-backed kernel when `eval_options.turbo`
is `Val(true)`, or — when no options are given — whenever the
LoopVectorization extension is loaded.

Returns the `ResultOk` produced by `_eval_graph_array`.
"""
function eval_tree_array(
    root::GraphNode{T},
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    eval_options::Union{EvalOptions,Nothing}=nothing,
) where {T}
    # BUGFIX: the original read `eval_options.turbo` unconditionally, which
    # throws when `eval_options === nothing` (the default), and referenced a
    # misspelled `eval_eval_options`. Guard the field access first.
    use_turbo = if isnothing(eval_options)
        _is_loopvectorization_loaded(0)
    else
        eval_options.turbo isa Val{true}
    end
    if use_turbo
        return _eval_graph_array(root, cX, operators, Val(true))
    else
        return _eval_graph_array(root, cX, operators, Val(false))
    end
end

# Graph evaluator that keeps per-node results in an external dictionary
# instead of `node.cache`, while still folding constants into `node.val` /
# `node.constant` on the nodes themselves.
#
# Returns the root's result vector on success, or `false` as soon as any
# intermediate value is non-finite. NOTE(review): this Bool-or-Vector return
# is type-unstable; callers must check `=== false` — consider a `ResultOk`
# wrapper like the sibling evaluators. Kept as-is for interface compatibility.
function eval_graph_array_diff(
    root::GraphNode{T},
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
) where {T}
    # vmap is faster with small cX sizes; vmapnt (non-temporal) is faster
    # with larger cX sizes (too big so not worth caching?).
    #
    # IMPROVEMENT: parametrize the Dict — the original `Dict{GraphNode,
    # AbstractArray{T}}` used an unparameterized abstract key type. All nodes
    # reachable from a `GraphNode{T}` root are `GraphNode{T}`, and every
    # stored result (a row `view` or a `map` output) is an AbstractVector{T}.
    dp = Dict{GraphNode{T},AbstractVector{T}}()
    order = topological_sort(root)
    for node in order
        if node.degree == 0 && !node.constant
            # Feature leaf: alias the matrix row without copying.
            dp[node] = view(cX, node.feature, :)
        elseif node.degree == 1
            if node.l.constant
                node.constant = true
                node.val = operators.unaops[node.op](node.l.val)
                if !is_valid(node.val)
                    return false
                end
            else
                node.constant = false
                dp[node] = map(operators.unaops[node.op], dp[node.l])
                if !is_valid_array(dp[node])
                    return false
                end
            end
        elseif node.degree == 2
            if node.l.constant
                if node.r.constant
                    node.constant = true
                    node.val = operators.binops[node.op](node.l.val, node.r.val)
                    if !is_valid(node.val)
                        return false
                    end
                else
                    node.constant = false
                    dp[node] = map(Base.Fix1(operators.binops[node.op], node.l.val), dp[node.r])
                    if !is_valid_array(dp[node])
                        return false
                    end
                end
            else
                if node.r.constant
                    node.constant = false
                    dp[node] = map(Base.Fix2(operators.binops[node.op], node.r.val), dp[node.l])
                    if !is_valid_array(dp[node])
                        return false
                    end
                else
                    node.constant = false
                    dp[node] = map(operators.binops[node.op], dp[node.l], dp[node.r])
                    if !is_valid_array(dp[node])
                        return false
                    end
                end
            end
        end
    end
    if root.constant
        # Whole graph folded to a scalar: broadcast it to one value per column.
        return fill(root.val, size(cX, 2))
    else
        return dp[root]
    end
end

# Evaluate the graph on a single sample: `cX` holds one scalar per feature.
# Writes each node's scalar result into `node.val` in topological order and
# returns the root's value, or `false` as soon as any operator produces a
# non-finite result.
function eval_graph_single(
    root::GraphNode{T},
    cX::AbstractArray{T},
    operators::OperatorEnum
) where {T}
    for node in topological_sort(root)
        if node.degree == 1
            node.val = operators.unaops[node.op](node.l.val)
            is_valid(node.val) || return false
        elseif node.degree == 2
            node.val = operators.binops[node.op](node.l.val, node.r.val)
            is_valid(node.val) || return false
        elseif !node.constant
            # Feature leaf (degree 0): read the sample value directly;
            # constant leaves already carry their value in `node.val`.
            node.val = cX[node.feature]
        end
    end
    return root.val
end

end
Loading
Loading