Skip to content

Commit 2a5109c

Browse files
committed
Convert common sub-functions as common sub-expressions
1 parent 1694a00 commit 2a5109c

File tree

3 files changed

+224
-3
lines changed

3 files changed

+224
-3
lines changed

src/Nonlinear/parse.jl

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,77 @@ function parse_expression(::Model, ::Expression, x::Any, ::Int)
3636
)
3737
end
3838

39+
function _extract_subexpression!(expr::Expression, root::Int)
40+
n = length(expr.nodes)
41+
# The whole subexpression is continuous in the tape
42+
first_out = first_value = last_value = nothing
43+
for i in root:n
44+
node = expr.nodes[i]
45+
if i != root && node.parent < root
46+
first_out = i
47+
break
48+
end
49+
index = node.index
50+
if node.type == NODE_VALUE
51+
if isnothing(first_value)
52+
first_value = node.index
53+
last_value = first_value
54+
else
55+
last_value = node.index
56+
end
57+
index -= first_value - 1
58+
end
59+
expr.nodes[i] =
60+
Node(node.type, index, i == root ? -1 : node.parent - root + 1)
61+
end
62+
if isnothing(first_out)
63+
I = root:n
64+
else
65+
I = root:(first_out-1)
66+
end
67+
if isnothing(first_value)
68+
V = nothing
69+
else
70+
V = first_value:last_value
71+
end
72+
if !isnothing(first_out)
73+
for i in (last(I)+1):n
74+
node = expr.nodes[i]
75+
index = node.index
76+
if node.type == NODE_VALUE && !isnothing(V)
77+
@assert index >= last(V)
78+
index -= length(V)
79+
end
80+
parent = node.parent
81+
if parent > root
82+
@assert parent > last(I)
83+
parent -= length(I) - 1
84+
end
85+
expr.nodes[i] = Node(node.type, index, parent)
86+
end
87+
end
88+
return I, V
89+
end
90+
91+
function _extract_subexpression!(data::Model, expr::Expression, root::Int)
92+
parent = expr.nodes[root].parent
93+
I, V = _extract_subexpression!(expr, root)
94+
subexpr =
95+
Expression(expr.nodes[I], isnothing(V) ? Float64[] : expr.values[V])
96+
push!(data.expressions, subexpr)
97+
index = ExpressionIndex(length(data.expressions))
98+
expr.nodes[root] = Node(NODE_SUBEXPRESSION, index.value, parent)
99+
if length(I) > 1
100+
deleteat!(expr.nodes, I[2:end])
101+
if !isnothing(V)
102+
deleteat!(expr.values, V)
103+
end
104+
else
105+
@assert isnothing(V)
106+
end
107+
return index, I
108+
end
109+
39110
function parse_expression(
40111
data::Model,
41112
expr::Expression,
@@ -46,7 +117,59 @@ function parse_expression(
46117
while !isempty(stack)
47118
parent_node, arg = pop!(stack)
48119
if arg isa MOI.ScalarNonlinearFunction
49-
_parse_without_recursion_inner(stack, data, expr, arg, parent_node)
120+
if haskey(data.cache, arg)
121+
subexpr = data.cache[arg]
122+
if subexpr isa Tuple{Expression,Int}
123+
_expr, _node = subexpr
124+
subexpr, I = _extract_subexpression!(data, _expr, _node)
125+
if expr === _expr
126+
if parent_node > first(I)
127+
@assert parent_node > last(I)
128+
parent_node -= length(I) - 1
129+
end
130+
for i in eachindex(stack)
131+
_parent_node = stack[i][1]
132+
if _parent_node > first(I)
133+
@assert _parent_node > last(I)
134+
stack[i] =
135+
(_parent_node - length(I) + 1, stack[i][2])
136+
end
137+
end
138+
end
139+
for (key, val) in data.cache
140+
if val isa Tuple{Expression,Int}
141+
__expr, __node = val
142+
if _expr === __expr && __node > first(I)
143+
if __node <= last(I)
144+
data.cache[key] = (
145+
data.expressions[subexpr.value],
146+
__node - first(I) + 1,
147+
)
148+
else
149+
data.cache[key] =
150+
(__expr, __node - length(I) + 1)
151+
end
152+
end
153+
end
154+
end
155+
data.cache[arg] = subexpr
156+
end
157+
parse_expression(
158+
data,
159+
expr,
160+
subexpr::ExpressionIndex,
161+
parent_node,
162+
)
163+
else
164+
_parse_without_recursion_inner(
165+
stack,
166+
data,
167+
expr,
168+
arg,
169+
parent_node,
170+
)
171+
data.cache[arg] = (expr, length(expr.nodes))
172+
end
50173
else
51174
# We can use recursion here, because ScalarNonlinearFunction only
52175
# occur in other ScalarNonlinearFunction.
@@ -82,7 +205,7 @@ function _parse_without_recursion_inner(stack, data, expr, x, parent)
82205
parent = length(expr.nodes)
83206
# Args need to be pushed onto the stack in reverse because the stack is a
84207
# first-in last-out datastructure.
85-
for arg in reverse(x.args)
208+
for arg in Iterators.Reverse(x.args)
86209
push!(stack, (parent, arg))
87210
end
88211
return

src/Nonlinear/types.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ tree.
7676
struct Expression
7777
nodes::Vector{Node}
7878
values::Vector{Float64}
79-
Expression() = new(Node[], Float64[])
8079
end
8180

81+
Expression() = Expression(Node[], Float64[])
82+
8283
function Base.:(==)(x::Expression, y::Expression)
8384
return x.nodes == y.nodes && x.values == y.values
8485
end
@@ -165,6 +166,11 @@ mutable struct Model
165166
operators::OperatorRegistry
166167
# This is a private field, used only to increment the ConstraintIndex.
167168
last_constraint_index::Int64
169+
# This is a private field, used to detect common subexpressions.
170+
cache::Dict{
171+
MOI.ScalarNonlinearFunction,
172+
Union{ExpressionIndex,Tuple{Expression,Int}},
173+
}
168174
function Model()
169175
return new(
170176
nothing,
@@ -173,6 +179,10 @@ mutable struct Model
173179
Float64[],
174180
OperatorRegistry(),
175181
0,
182+
Dict{
183+
MOI.ScalarNonlinearFunction,
184+
Union{ExpressionIndex,Tuple{Expression,Int}},
185+
}(),
176186
)
177187
end
178188
end

test/Nonlinear/Nonlinear.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,94 @@ function test_intercept_ForwardDiff_MethodError()
14461446
return
14471447
end
14481448

1449+
function test_extract_subexpression()
1450+
model = Nonlinear.Model()
1451+
x = MOI.VariableIndex(1)
1452+
sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
1453+
f = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1454+
expr = Nonlinear.parse_expression(model, f)
1455+
display(expr.nodes)
1456+
@test expr == Nonlinear.Expression(
1457+
[
1458+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
1459+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1460+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1461+
],
1462+
Float64[],
1463+
)
1464+
expected_sub = Nonlinear.Expression(
1465+
[
1466+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 4, -1)
1467+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1)
1468+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1)
1469+
],
1470+
[3.0],
1471+
)
1472+
@test model.expressions == [expected_sub]
1473+
@test model.cache[sub] == Nonlinear.ExpressionIndex(1)
1474+
1475+
h = MOI.ScalarNonlinearFunction(:*, Any[2, sub, 1])
1476+
g = MOI.ScalarNonlinearFunction(:+, Any[sub, h])
1477+
expr = MOI.Nonlinear.parse_expression(model, g)
1478+
expected_g = Nonlinear.Expression(
1479+
[
1480+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1)
1481+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1)
1482+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1)
1483+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 3)
1484+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 3)
1485+
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 3)
1486+
],
1487+
[2.0, 1.0],
1488+
)
1489+
@test expr == expected_g
1490+
# It should have detected the sub-expressions that was the same as `f`
1491+
@test model.expressions == [expected_sub]
1492+
# This means that it didn't get to extract from `g`, let's also test
1493+
# with extraction by starting with an empty model
1494+
1495+
model = Nonlinear.Model()
1496+
MOI.Nonlinear.set_objective(model, g)
1497+
@test model.objective == expected_g
1498+
@test model.expressions == [expected_sub]
1499+
# Test that the objective function gets rewritten as we reuse `h`
1500+
# Also test that we don't change the parents in the stack of `h`
1501+
# by creating a long stack
1502+
prod = MOI.ScalarNonlinearFunction(:*, [h, x])
1503+
sum = MOI.ScalarNonlinearFunction(:*, [x, x, x, x, prod])
1504+
expr = Nonlinear.parse_expression(model, sum)
1505+
@test isempty(model.objective.values)
1506+
@test model.objective.nodes == [
1507+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
1508+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1509+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 1),
1510+
]
1511+
@test model.expressions == [
1512+
expected_sub,
1513+
Nonlinear.Expression(
1514+
[
1515+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
1516+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1),
1517+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1518+
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 1),
1519+
],
1520+
[2.0, 1.0],
1521+
),
1522+
]
1523+
@test isempty(expr.values)
1524+
@test expr.nodes == [
1525+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
1526+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1527+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1528+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1529+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1530+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
1531+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 6),
1532+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 6),
1533+
]
1534+
return
1535+
end
1536+
14491537
end # TestNonlinear
14501538

14511539
TestNonlinear.runtests()

0 commit comments

Comments
 (0)