Skip to content

Commit fe56817

Browse files
authored
Minor fgen fixes (#11)
* Minor fgen fixes - Add `var_names(::GradTransform, [...])` and `all_names` for generating symbolic subgradient variables - Fix inversion operation in division when lower and upper bounds are the same - Add more support for higher variable indices - Add error for `binarize!` to assist in debugging - Add `split_div::Bool` argument to `factor`, which determines whether terms like `x/y` will remain as-is or get factored into `x * (1/y)` - Improve variable sorting so that, e.g., `x10` will not come before `x2` - Add `extract`, which can pull a subexpression out of a primal trace and substitute out auxiliary variables - Fix `fgen` bug that would use an incorrect subgradient variable * Update ci.yml * Remove `fgen` division tests - Tests for division operations removed for `fgen`. Rules were written to match McCormick.jl, which was modified after this issue: PSORLab/McCormick.jl#69. This change only affects division by a negative McCormick object when the convex and concave relaxation values are not the same, but since `fgen` will be deprecated in v0.5, this issue in `fgen` is not planned to be fixed.
1 parent 57cba3f commit fe56817

File tree

10 files changed

+199
-48
lines changed

10 files changed

+199
-48
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
with:
3535
version: ${{ matrix.version }}
3636
arch: ${{ matrix.arch }}
37-
- uses: actions/cache@v1
37+
- uses: actions/cache@v4
3838
env:
3939
cache-name: cache-artifacts
4040
with:

src/grad/grad.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,37 @@
11

2-
# Can remove the import statement once this is fully incorporated into SCMC
3-
# import SourceCodeMcCormick: xstr, ystr, zstr, var_names, arity, op, transform_rule
2+
struct GradTransform <: AbstractTransform end
3+
4+
var_names(::GradTransform, a::Real) = a, a
5+
function var_names(::GradTransform, a::BasicSymbolic)
6+
if exprtype(a)==SYM
7+
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"))
8+
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"))
9+
return acvgrad.val, accgrad.val
10+
elseif exprtype(a)==TERM
11+
if varterm(a) && typeof(a.f)<:BasicSymbolic
12+
arg_list = Symbol[]
13+
for i in a.arguments
14+
push!(arg_list, get_name(i))
15+
end
16+
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"), arg_list)
17+
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"), arg_list)
18+
return acvgrad.val, accgrad.val
19+
else
20+
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"))
21+
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"))
22+
return acvgrad.val, accgrad.val
23+
end
24+
else
25+
error("Reached `var_names` with an unexpected type [ADD/MUL/DIV/POW]. Check expression factorization to make sure it is being binarized correctly.")
26+
end
27+
end
28+
29+
function all_names(a::Any)
30+
aL, aU = var_names(IntervalTransform(), a)
31+
acv, acc = var_names(McCormickTransform(), a)
32+
acvgrad, accgrad = var_names(GradTransform(), a)
33+
return aL, aU, acv, acc, acvgrad, accgrad
34+
end
435

536

637
"""

src/grad/rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ function grad_transform!(::McCormickIntervalTransform, ::typeof(/), zL, zU, zcv,
272272
IfElse.ifelse(yU < 0.0, mid_expr(ycc, ycv, yU).^(-1),
273273
NaN))
274274
y_cv_gradlist_inv = similar(cv_gradlist[:,y])
275-
@. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)), (yU^-1 - yL^-1)/(yU - yL)) *
275+
@. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)) * mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec), (yU^-1 - yL^-1)/(yU - yL)) *
276276
mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec),
277277
IfElse.ifelse(yL > 0.0, -1.0/(mid_expr(ycc, ycv, yU)*mid_expr(ycc, ycv, yU)) *
278278
mid_grad(ycc, ycv, yU, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec),

src/interval/interval.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ end
5757
get_name
5858
5959
Take a `BasicSymbolic` object such as `x[1,1]` and return a symbol like `:xv1v1`.
60+
Note that this supports up to 9999-indexed variables (higher than that will still
61+
work, but the order will be wrong)
6062
"""
6163
function get_name(a::BasicSymbolic)
6264
if exprtype(a)==SYM
@@ -67,7 +69,18 @@ function get_name(a::BasicSymbolic)
6769
args = a.arguments
6870
new_var = string(args[1])
6971
for i in 2:lastindex(args)
70-
new_var = new_var * "v" * string(args[i])
72+
if args[i] < 10
73+
new_var = new_var * "v000" * string(args[i])
74+
elseif args[i] < 100
75+
new_var = new_var * "v00" * string(args[i])
76+
elseif args[i] < 1000
77+
new_var = new_var * "v0" * string(args[i])
78+
elseif args[i] < 10000
79+
new_var = new_var * "v" * string(args[i])
80+
else
81+
@warn "Index above 10000, order may be wrong"
82+
new_var = new_var * "v" * string(args[i])
83+
end
7184
end
7285
return Symbol(new_var)
7386
else

src/transform/binarize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ function binarize!(ex::BasicSymbolic)
3535
return nothing
3636
end
3737
end
38+
binarize!(a::Real) = error("Attempting to apply binarize!() to a Real")

src/transform/factor.jl

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11

22
base_term(a::Any) = false
33
base_term(a::Real) = true
4+
base_term(a::Num) = base_term(a.val)
45
function base_term(a::BasicSymbolic)
56
exprtype(a)==SYM && return true
67
exprtype(a)==TERM && return varterm(a) || (a.f==getindex)
78
return false
89
end
910

10-
function isfactor(a::BasicSymbolic)
11+
function isfactor(a::BasicSymbolic; split_div::Bool=false)
1112
if exprtype(a)==SYM
1213
return true
1314
elseif exprtype(a)==TERM
@@ -32,7 +33,12 @@ function isfactor(a::BasicSymbolic)
3233
~(base_term(key)) && return false
3334
end
3435
return true
35-
elseif exprtype(a)==DIV
36+
elseif exprtype(a)==DIV && split_div==true
37+
~(typeof(a.num)<:Real) && return false
38+
~(isone(a.num)) && return false
39+
~(base_term(a.den)) && return false
40+
return true
41+
elseif exprtype(a)==DIV && split_div==false
3642
~(base_term(a.num)) && return false
3743
~(base_term(a.den)) && return false
3844
return true
@@ -47,13 +53,13 @@ function factor!(a...)
4753
@warn """Use of "!" is deprecated as of v0.2.0. Please call `factor()` instead."""
4854
return factor(a...)
4955
end
50-
factor(ex::Num) = factor(ex.val)
51-
factor(ex::Num, eqs::Vector{Equation}) = factor(ex.val, eqs=eqs)
56+
factor(ex::Num; split_div::Bool=false) = factor(ex.val, split_div=split_div)
57+
factor(ex::Num, eqs::Vector{Equation}; split_div::Bool=false) = factor(ex.val, eqs=eqs, split_div=split_div)
5258

53-
function factor(old_ex::BasicSymbolic; eqs = Equation[])
59+
function factor(old_ex::BasicSymbolic; eqs = Equation[], split_div::Bool = false)
5460
ex = deepcopy(old_ex)
5561
binarize!(ex)
56-
if isfactor(ex)
62+
if isfactor(ex, split_div=split_div)
5763
index = findall(x -> isequal(x.rhs,ex), eqs)
5864
if isempty(index)
5965
newsym = gensym(:aux)
@@ -87,12 +93,12 @@ function factor(old_ex::BasicSymbolic; eqs = Equation[])
8793
new_terms[eqs[index[1]].lhs] = 1
8894
end
8995
else
90-
factor(val*key, eqs=eqs)
96+
factor(val*key, eqs=eqs, split_div=split_div)
9197
new_terms[eqs[end].lhs] = 1
9298
end
9399
end
94100
new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms)
95-
factor(new_add, eqs=eqs)
101+
factor(new_add, eqs=eqs, split_div=split_div)
96102
return eqs
97103
elseif exprtype(ex)==MUL
98104
new_terms = Dict{Any, Number}()
@@ -112,57 +118,93 @@ function factor(old_ex::BasicSymbolic; eqs = Equation[])
112118
new_terms[eqs[index[1]].lhs] = 1
113119
end
114120
else
115-
factor(key^val, eqs=eqs)
121+
factor(key^val, eqs=eqs, split_div=split_div)
116122
new_terms[eqs[end].lhs] = 1
117123
end
118124
end
119125
new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms)
120-
factor(new_mul, eqs=eqs)
126+
factor(new_mul, eqs=eqs, split_div=split_div)
121127
return eqs
122128
elseif exprtype(ex)==DIV
123-
if base_term(ex.num)
124-
new_num = ex.num
125-
else
126-
factor(ex.num, eqs=eqs)
127-
new_num = eqs[end].lhs
128-
end
129-
if base_term(ex.den)
130-
new_den = ex.den
129+
if split_div
130+
if isone(Num(ex.num))
131+
if base_term(ex.den)
132+
new_den = ex.den
133+
else
134+
factor(ex.den, eqs=eqs, split_div=split_div)
135+
new_den = eqs[end].lhs
136+
end
137+
new_div = SymbolicUtils.Div(1, new_den)
138+
factor(new_div, eqs=eqs, split_div=split_div)
139+
return eqs
140+
else
141+
new_terms = Dict{Any, Number}()
142+
coeff = 1
143+
if base_term(ex.num)
144+
if typeof(ex.num)<:Real
145+
coeff = ex.num
146+
else
147+
new_terms[ex.num] = 1
148+
end
149+
else
150+
factor(ex.num, eqs=eqs, split_div=split_div)
151+
new_terms[eqs[end].lhs] = 1
152+
end
153+
if base_term(ex.den)
154+
new_terms[1/ex.den] = 1
155+
else
156+
factor(1/ex.den, eqs=eqs, split_div=split_div)
157+
new_terms[eqs[end].lhs] = 1
158+
end
159+
new_mul = SymbolicUtils.Mul(Real, coeff, new_terms)
160+
factor(new_mul, eqs=eqs, split_div=split_div)
161+
return eqs
162+
end
131163
else
132-
factor(ex.den, eqs=eqs)
133-
new_den = eqs[end].lhs
164+
if base_term(ex.num)
165+
new_num = ex.num
166+
else
167+
factor(ex.num, eqs=eqs, split_div=split_div)
168+
new_num = eqs[end].lhs
169+
end
170+
if base_term(ex.den)
171+
new_den = ex.den
172+
else
173+
factor(ex.den, eqs=eqs, split_div=split_div)
174+
new_den = eqs[end].lhs
175+
end
176+
new_div = SymbolicUtils.Div(new_num, new_den)
177+
factor(new_div, eqs=eqs, split_div=split_div)
178+
return eqs
134179
end
135-
new_div = SymbolicUtils.Div(new_num, new_den)
136-
factor(new_div, eqs=eqs)
137-
return eqs
138180
elseif exprtype(ex)==POW
139181
if base_term(ex.base)
140182
new_base = ex.base
141183
else
142-
factor(ex.base, eqs=eqs)
184+
factor(ex.base, eqs=eqs, split_div=split_div)
143185
new_base = eqs[end].lhs
144186
end
145187
if base_term(ex.exp)
146188
new_exp = ex.exp
147189
else
148-
factor(ex.exp, eqs=eqs)
190+
factor(ex.exp, eqs=eqs, split_div=split_div)
149191
new_exp = eqs[end].lhs
150192
end
151193
new_pow = SymbolicUtils.Pow(new_base, new_exp)
152-
factor(new_pow, eqs=eqs)
194+
factor(new_pow, eqs=eqs, split_div=split_div)
153195
return eqs
154196
elseif exprtype(ex)==TERM
155197
new_args = []
156198
for arg in ex.arguments
157199
if base_term(arg)
158200
push!(new_args, arg)
159201
else
160-
factor(arg, eqs=eqs)
202+
factor(arg, eqs=eqs, split_div=split_div)
161203
push!(new_args, eqs[end].lhs)
162204
end
163205
end
164206
new_func = SymbolicUtils.Term(ex.f, new_args)
165-
factor(new_func, eqs=eqs)
207+
factor(new_func, eqs=eqs, split_div=split_div)
166208
return eqs
167209
end
168210
return eqs

src/transform/utilities.jl

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,18 +290,34 @@ end
290290
function sort_vars(strings::Vector{String})
291291
# isempty(strings) && return
292292
sort_names = fill("", length(strings))
293-
293+
294294
# Step 1) Check for derivative-type variables
295295
# @show strings
296296
# split_strings = string.(hcat(split.(strings, "_")...)[1,:])
297297
split_strings = first.(split.(strings, "_"))
298298
if strings == split_strings
299+
# Simpler case; we can sort more-or-less normally
300+
301+
# Put constants first, if any exist
299302
for i in eachindex(split_strings)
300303
if split_strings[i]=="constant"
301304
split_strings[i] = "_____constant"
302305
end
303306
end
304-
return sortperm(split_strings)
307+
308+
# Sort split_strings, and if the strings follow the pattern [letters][numbers],
309+
# sort by [letters] first and then by [numbers]. Otherwise, treat the string
310+
# just as a normal string
311+
return sortperm(split_strings, by = s -> begin
312+
m = match(r"([a-zA-Z]+)(\d+)", s)
313+
if m !== nothing
314+
prefix = m.captures[1]
315+
number = parse(Int, m.captures[2])
316+
return (prefix, number)
317+
else
318+
return (s, 0)
319+
end
320+
end)
305321
end
306322
deriv = fill(false, length(strings))
307323
# Here's a way to check for derivatives if we need to go back to "d" instead of '∂'
@@ -399,16 +415,16 @@ end
399415

400416
function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{String})
401417
if exprtype(term)==SYM
402-
if ~(string(term) in strings)
403-
push!(strings, string(term))
418+
if ~(string(get_name(term)) in strings)
419+
push!(strings, string(get_name(term)))
404420
push!(vars, term)
405421
return vars, strings
406422
end
407423
return vars, strings
408424
end
409425
if exprtype(term)==TERM && varterm(term)
410426
if ~(string(term.f) in strings) && (term.f==getindex && ~(string(term) in string.(vars)))
411-
push!(strings, string(term))
427+
push!(strings, string(get_name(term)))
412428
push!(vars, term)
413429
return vars, strings
414430
end
@@ -421,8 +437,8 @@ function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{Stri
421437
end
422438
~(typeof(arg)<:BasicSymbolic) ? continue : nothing
423439
if exprtype(arg)==SYM
424-
if ~(string(arg) in strings)
425-
push!(strings, string(arg))
440+
if ~(string(get_name(arg)) in strings)
441+
push!(strings, string(get_name(arg)))
426442
push!(vars, arg)
427443
end
428444
elseif typeof(arg) <: Real
@@ -450,8 +466,8 @@ number of equations remaining (default = 4).
450466
```
451467
eqs = [y ~ 15*x,
452468
z ~ (1+y)^2]
453-
shrink_eqs(eqs, 1)
454469
470+
julia> shrink_eqs(eqs, 1)
455471
1-element Vector{Equation}:
456472
z ~ (1 + 15x)^2
457473
```
@@ -476,6 +492,47 @@ function shrink_eqs(eqs::Vector{Equation}, keep::Int64=4; force::Bool=false)
476492
return new_eqs
477493
end
478494

495+
"""
496+
extract(::Vector{Equation})
497+
extract(::Vector{Equation}, ::Int)
498+
499+
Given a set of symbolic equations, and optinally a specific
500+
element that you would like expanded into a full expression,
501+
progressively substitute the RHS definitions of LHS terms
502+
until there is only that equation remaining (default = end).
503+
Returns the RHS of that expression as an object of type
504+
SymbolicUtils.BasicSymbolic{Real}.
505+
506+
# Example
507+
508+
```
509+
eqs = [y ~ 15*x,
510+
z ~ (1+y)^2]
511+
512+
julia> extract(eqs)
513+
(1 + 15x)^2
514+
```
515+
"""
516+
function extract(eqs::Vector{Equation}, ID::Int=length(eqs))
517+
final_expr = eqs[ID].rhs
518+
progress = true
519+
while progress
520+
progress = false
521+
for var in pull_vars(final_expr)
522+
eq_ID = findfirst(x -> isequal(x.lhs, var), eqs)
523+
if !isnothing(eq_ID)
524+
if isequal(eqs[eq_ID].lhs, eqs[eq_ID].rhs)
525+
nothing
526+
else
527+
final_expr = substitute(final_expr, Dict(eqs[eq_ID].lhs => eqs[eq_ID].rhs))
528+
progress = true
529+
end
530+
end
531+
end
532+
end
533+
return final_expr
534+
end
535+
479536
"""
480537
convex_evaluator(::Num)
481538
convex_evaluator(::Equation)

0 commit comments

Comments
 (0)