Skip to content

Commit c8010da

Browse files
committed
feat: more extensible expression interface
1 parent 3dbded5 commit c8010da

File tree

5 files changed

+80
-6
lines changed

5 files changed

+80
-6
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ import .NodeModule:
6666
@reexport import .EvaluationHelpersModule
6767
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
6868
@reexport import .RandomModule: NodeSampler
69-
@reexport import .ExpressionModule: AbstractExpression, Expression, with_tree
69+
@reexport import .ExpressionModule:
70+
AbstractExpression, Expression, with_tree, with_metadata, get_contents, get_metadata
7071
import .ExpressionModule:
7172
get_tree, get_operators, get_variable_names, Metadata, default_node_type, node_type
7273
@reexport import .ParseModule: @parse_expression, parse_expression

src/Expression.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,46 @@ end
140140
function set_constants!(ex::AbstractExpression{T}, constants, refs) where {T}
141141
return error("`set_constants!` function must be implemented for $(typeof(ex)) types.")
142142
end
143+
function get_contents(ex::AbstractExpression)
144+
return error("`get_contents` function must be implemented for $(typeof(ex)) types.")
145+
end
146+
function get_metadata(ex::AbstractExpression)
147+
return error("`get_metadata` function must be implemented for $(typeof(ex)) types.")
148+
end
143149
########################################################
144150

145151
"""
146152
with_tree(ex::AbstractExpression, tree::AbstractExpressionNode)
153+
with_tree(ex::AbstractExpression, tree::AbstractExpression)
147154
148155
Create a new expression based on `ex` but with a different `tree`
149156
"""
157+
function with_tree(ex::AbstractExpression, tree::AbstractExpression)
158+
return with_tree(ex, get_contents(tree))
159+
end
150160
function with_tree(ex::AbstractExpression, tree)
151-
return constructorof(typeof(ex))(tree, ex.metadata)
161+
return constructorof(typeof(ex))(tree, get_metadata(ex))
152162
end
163+
function get_contents(ex::Expression)
164+
return ex.tree
165+
end
166+
167+
"""
168+
with_metadata(ex::AbstractExpression, metadata)
169+
with_metadata(ex::AbstractExpression; metadata...)
170+
171+
Create a new expression based on `ex` but with a different `metadata`.
172+
"""
173+
function with_metadata(ex::AbstractExpression; metadata...)
174+
return with_metadata(get_contents(ex), Metadata(metadata))
175+
end
176+
function with_metadata(ex::AbstractExpression, metadata::Metadata)
177+
return constructorof(typeof(ex))(get_contents(ex), metadata)
178+
end
179+
function get_metadata(ex::Expression)
180+
return ex.metadata
181+
end
182+
153183
function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T,N}}
154184
return preserve_sharing(N)
155185
end

src/Interfaces.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ using ..ExpressionModule:
4343
get_tree,
4444
get_operators,
4545
get_variable_names,
46+
get_contents,
47+
get_metadata,
4648
with_tree,
49+
with_metadata,
4750
default_node_type
4851
using ..ParametricExpressionModule: ParametricExpression, ParametricNode
4952

@@ -52,6 +55,14 @@ using ..ParametricExpressionModule: ParametricExpression, ParametricNode
5255
###############################################################################
5356

5457
## mandatory
58+
function _check_get_contents(ex::AbstractExpression)
59+
new_ex = with_tree(ex, get_contents(ex))
60+
return new_ex == ex && new_ex isa typeof(ex)
61+
end
62+
function _check_get_metadata(ex::AbstractExpression)
63+
new_ex = with_metadata(ex, get_metadata(ex))
64+
return new_ex == ex && new_ex isa typeof(ex)
65+
end
5566
function _check_get_tree(ex::AbstractExpression{T,N}) where {T,N}
5667
return get_tree(ex) isa N
5768
end
@@ -67,6 +78,15 @@ function _check_copy(ex::AbstractExpression)
6778
# TODO: Could include checks for aliasing here
6879
return preserves
6980
end
81+
function _check_with_tree(ex::AbstractExpression)
82+
new_ex = with_tree(ex, get_contents(ex))
83+
new_ex2 = with_tree(ex, ex)
84+
return new_ex == ex && new_ex isa typeof(ex) && new_ex2 == ex && new_ex2 isa typeof(ex)
85+
end
86+
function _check_with_metadata(ex::AbstractExpression)
87+
new_ex = with_metadata(ex, get_metadata(ex))
88+
return new_ex == ex && new_ex isa typeof(ex)
89+
end
7090

7191
## optional
7292
function _check_count_nodes(ex::AbstractExpression)
@@ -116,10 +136,14 @@ end
116136
#! format: off
117137
ei_components = (
118138
mandatory = (
139+
get_contents = "extracts the runtime contents of an expression" => _check_get_contents,
140+
get_metadata = "extracts the runtime metadata of an expression" => _check_get_metadata,
119141
get_tree = "extracts the expression tree from [`AbstractExpression`](@ref)" => _check_get_tree,
120142
get_operators = "returns the operators used in the expression (or pass `operators` explicitly to override)" => _check_get_operators,
121143
get_variable_names = "returns the variable names used in the expression (or pass `variable_names` explicitly to override)" => _check_get_variable_names,
122144
copy = "returns a copy of the expression" => _check_copy,
145+
with_tree = "returns the expression with different tree" => _check_with_tree,
146+
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
123147
),
124148
optional = (
125149
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,

src/ParametricExpression.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ import ..EvaluateModule: eval_tree_array
1919
import ..EvaluateDerivativeModule: eval_grad_tree_array
2020
import ..EvaluationHelpersModule: _grad_evaluator
2121
import ..ExpressionModule:
22-
get_tree, get_operators, get_variable_names, max_feature, default_node_type
22+
get_contents,
23+
get_metadata,
24+
get_tree,
25+
get_operators,
26+
get_variable_names,
27+
max_feature,
28+
default_node_type
2329
import ..ParseModule: parse_leaf
2430

2531
"""A type of expression node that also stores a parameter index"""
@@ -127,9 +133,9 @@ end
127133
###############################################################################
128134
# Abstract expression interface ###############################################
129135
###############################################################################
130-
function get_tree(ex::ParametricExpression)
131-
return ex.tree
132-
end
136+
get_contents(ex::ParametricExpression) = ex.tree
137+
get_metadata(ex::ParametricExpression) = ex.metadata
138+
get_tree(ex::ParametricExpression) = ex.tree
133139
function get_operators(ex::ParametricExpression, operators=nothing)
134140
return operators === nothing ? ex.metadata.operators : operators
135141
end

test/test_multi_expression.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
trees::TREES
1111
metadata::Metadata{D}
1212

13+
function MultiScalarExpression(trees::NamedTuple, metadata::Metadata{D}) where {D}
14+
example_tree = first(values(trees))
15+
N = typeof(example_tree)
16+
T = eltype(example_tree)
17+
return new{T,N,typeof(trees),D}(trees, metadata)
18+
end
19+
1320
"""
1421
Create a multi-expression expression type.
1522
@@ -65,6 +72,12 @@
6572
end
6673

6774
tree_factory(f::F, trees) where {F} = f(; trees...)
75+
function DE.get_contents(ex::MultiScalarExpression)
76+
return ex.trees
77+
end
78+
function DE.get_metadata(ex::MultiScalarExpression)
79+
return ex.metadata
80+
end
6881
function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N}
6982
fused_expression = parse_expression(
7083
tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;

0 commit comments

Comments
 (0)