Skip to content

Minor fgen fixes #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
Expand Down
35 changes: 33 additions & 2 deletions src/grad/grad.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@

# Can remove the import statement once this is fully incorporated into SCMC
# import SourceCodeMcCormick: xstr, ystr, zstr, var_names, arity, op, transform_rule
struct GradTransform <: AbstractTransform end

var_names(::GradTransform, a::Real) = a, a
function var_names(::GradTransform, a::BasicSymbolic)
if exprtype(a)==SYM
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"))
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"))
return acvgrad.val, accgrad.val
elseif exprtype(a)==TERM
if varterm(a) && typeof(a.f)<:BasicSymbolic
arg_list = Symbol[]
for i in a.arguments
push!(arg_list, get_name(i))
end
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"), arg_list)
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"), arg_list)
return acvgrad.val, accgrad.val

Check warning on line 18 in src/grad/grad.jl

View check run for this annotation

Codecov / codecov/patch

src/grad/grad.jl#L4-L18

Added lines #L4 - L18 were not covered by tests
else
acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"))
accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"))
return acvgrad.val, accgrad.val

Check warning on line 22 in src/grad/grad.jl

View check run for this annotation

Codecov / codecov/patch

src/grad/grad.jl#L20-L22

Added lines #L20 - L22 were not covered by tests
end
else
error("Reached `var_names` with an unexpected type [ADD/MUL/DIV/POW]. Check expression factorization to make sure it is being binarized correctly.")

Check warning on line 25 in src/grad/grad.jl

View check run for this annotation

Codecov / codecov/patch

src/grad/grad.jl#L25

Added line #L25 was not covered by tests
end
end

function all_names(a::Any)
aL, aU = var_names(IntervalTransform(), a)
acv, acc = var_names(McCormickTransform(), a)
acvgrad, accgrad = var_names(GradTransform(), a)
return aL, aU, acv, acc, acvgrad, accgrad

Check warning on line 33 in src/grad/grad.jl

View check run for this annotation

Codecov / codecov/patch

src/grad/grad.jl#L29-L33

Added lines #L29 - L33 were not covered by tests
end


"""
Expand Down
2 changes: 1 addition & 1 deletion src/grad/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@
IfElse.ifelse(yU < 0.0, mid_expr(ycc, ycv, yU).^(-1),
NaN))
y_cv_gradlist_inv = similar(cv_gradlist[:,y])
@. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)), (yU^-1 - yL^-1)/(yU - yL)) *
@. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)) * mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec), (yU^-1 - yL^-1)/(yU - yL)) *

Check warning on line 275 in src/grad/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/grad/rules.jl#L275

Added line #L275 was not covered by tests
mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec),
IfElse.ifelse(yL > 0.0, -1.0/(mid_expr(ycc, ycv, yU)*mid_expr(ycc, ycv, yU)) *
mid_grad(ycc, ycv, yU, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec),
Expand Down
15 changes: 14 additions & 1 deletion src/interval/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
get_name

Take a `BasicSymbolic` object such as `x[1,1]` and return a symbol like `:xv1v1`.
Note that this supports up to 9999-indexed variables (higher than that will still
work, but the order will be wrong)
"""
function get_name(a::BasicSymbolic)
if exprtype(a)==SYM
Expand All @@ -67,7 +69,18 @@
args = a.arguments
new_var = string(args[1])
for i in 2:lastindex(args)
new_var = new_var * "v" * string(args[i])
if args[i] < 10
new_var = new_var * "v000" * string(args[i])
elseif args[i] < 100
new_var = new_var * "v00" * string(args[i])
elseif args[i] < 1000
new_var = new_var * "v0" * string(args[i])
elseif args[i] < 10000
new_var = new_var * "v" * string(args[i])

Check warning on line 79 in src/interval/interval.jl

View check run for this annotation

Codecov / codecov/patch

src/interval/interval.jl#L72-L79

Added lines #L72 - L79 were not covered by tests
else
@warn "Index above 10000, order may be wrong"
new_var = new_var * "v" * string(args[i])

Check warning on line 82 in src/interval/interval.jl

View check run for this annotation

Codecov / codecov/patch

src/interval/interval.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
end
end
return Symbol(new_var)
else
Expand Down
1 change: 1 addition & 0 deletions src/transform/binarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
return nothing
end
end
binarize!(a::Real) = error("Attempting to apply binarize!() to a Real")

Check warning on line 38 in src/transform/binarize.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/binarize.jl#L38

Added line #L38 was not covered by tests
98 changes: 70 additions & 28 deletions src/transform/factor.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@

base_term(a::Any) = false
base_term(a::Real) = true
base_term(a::Num) = base_term(a.val)

Check warning on line 4 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L4

Added line #L4 was not covered by tests
function base_term(a::BasicSymbolic)
exprtype(a)==SYM && return true
exprtype(a)==TERM && return varterm(a) || (a.f==getindex)
return false
end

function isfactor(a::BasicSymbolic)
function isfactor(a::BasicSymbolic; split_div::Bool=false)
if exprtype(a)==SYM
return true
elseif exprtype(a)==TERM
Expand All @@ -32,7 +33,12 @@
~(base_term(key)) && return false
end
return true
elseif exprtype(a)==DIV
elseif exprtype(a)==DIV && split_div==true
~(typeof(a.num)<:Real) && return false
~(isone(a.num)) && return false
~(base_term(a.den)) && return false
return true

Check warning on line 40 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L37-L40

Added lines #L37 - L40 were not covered by tests
elseif exprtype(a)==DIV && split_div==false
~(base_term(a.num)) && return false
~(base_term(a.den)) && return false
return true
Expand All @@ -47,13 +53,13 @@
@warn """Use of "!" is deprecated as of v0.2.0. Please call `factor()` instead."""
return factor(a...)
end
factor(ex::Num) = factor(ex.val)
factor(ex::Num, eqs::Vector{Equation}) = factor(ex.val, eqs=eqs)
factor(ex::Num; split_div::Bool=false) = factor(ex.val, split_div=split_div)
factor(ex::Num, eqs::Vector{Equation}; split_div::Bool=false) = factor(ex.val, eqs=eqs, split_div=split_div)

Check warning on line 57 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L56-L57

Added lines #L56 - L57 were not covered by tests

function factor(old_ex::BasicSymbolic; eqs = Equation[])
function factor(old_ex::BasicSymbolic; eqs = Equation[], split_div::Bool = false)
ex = deepcopy(old_ex)
binarize!(ex)
if isfactor(ex)
if isfactor(ex, split_div=split_div)
index = findall(x -> isequal(x.rhs,ex), eqs)
if isempty(index)
newsym = gensym(:aux)
Expand Down Expand Up @@ -87,12 +93,12 @@
new_terms[eqs[index[1]].lhs] = 1
end
else
factor(val*key, eqs=eqs)
factor(val*key, eqs=eqs, split_div=split_div)

Check warning on line 96 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L96

Added line #L96 was not covered by tests
new_terms[eqs[end].lhs] = 1
end
end
new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms)
factor(new_add, eqs=eqs)
factor(new_add, eqs=eqs, split_div=split_div)

Check warning on line 101 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L101

Added line #L101 was not covered by tests
return eqs
elseif exprtype(ex)==MUL
new_terms = Dict{Any, Number}()
Expand All @@ -112,57 +118,93 @@
new_terms[eqs[index[1]].lhs] = 1
end
else
factor(key^val, eqs=eqs)
factor(key^val, eqs=eqs, split_div=split_div)

Check warning on line 121 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L121

Added line #L121 was not covered by tests
new_terms[eqs[end].lhs] = 1
end
end
new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms)
factor(new_mul, eqs=eqs)
factor(new_mul, eqs=eqs, split_div=split_div)

Check warning on line 126 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L126

Added line #L126 was not covered by tests
return eqs
elseif exprtype(ex)==DIV
if base_term(ex.num)
new_num = ex.num
else
factor(ex.num, eqs=eqs)
new_num = eqs[end].lhs
end
if base_term(ex.den)
new_den = ex.den
if split_div
if isone(Num(ex.num))
if base_term(ex.den)
new_den = ex.den

Check warning on line 132 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L129-L132

Added lines #L129 - L132 were not covered by tests
else
factor(ex.den, eqs=eqs, split_div=split_div)
new_den = eqs[end].lhs

Check warning on line 135 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L134-L135

Added lines #L134 - L135 were not covered by tests
end
new_div = SymbolicUtils.Div(1, new_den)
factor(new_div, eqs=eqs, split_div=split_div)
return eqs

Check warning on line 139 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L137-L139

Added lines #L137 - L139 were not covered by tests
else
new_terms = Dict{Any, Number}()
coeff = 1
if base_term(ex.num)
if typeof(ex.num)<:Real
coeff = ex.num

Check warning on line 145 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L141-L145

Added lines #L141 - L145 were not covered by tests
else
new_terms[ex.num] = 1

Check warning on line 147 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L147

Added line #L147 was not covered by tests
end
else
factor(ex.num, eqs=eqs, split_div=split_div)
new_terms[eqs[end].lhs] = 1

Check warning on line 151 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L150-L151

Added lines #L150 - L151 were not covered by tests
end
if base_term(ex.den)
new_terms[1/ex.den] = 1

Check warning on line 154 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L153-L154

Added lines #L153 - L154 were not covered by tests
else
factor(1/ex.den, eqs=eqs, split_div=split_div)
new_terms[eqs[end].lhs] = 1

Check warning on line 157 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
end
new_mul = SymbolicUtils.Mul(Real, coeff, new_terms)
factor(new_mul, eqs=eqs, split_div=split_div)
return eqs

Check warning on line 161 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L159-L161

Added lines #L159 - L161 were not covered by tests
end
else
factor(ex.den, eqs=eqs)
new_den = eqs[end].lhs
if base_term(ex.num)
new_num = ex.num

Check warning on line 165 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L164-L165

Added lines #L164 - L165 were not covered by tests
else
factor(ex.num, eqs=eqs, split_div=split_div)
new_num = eqs[end].lhs

Check warning on line 168 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L167-L168

Added lines #L167 - L168 were not covered by tests
end
if base_term(ex.den)
new_den = ex.den

Check warning on line 171 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L170-L171

Added lines #L170 - L171 were not covered by tests
else
factor(ex.den, eqs=eqs, split_div=split_div)
new_den = eqs[end].lhs

Check warning on line 174 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L173-L174

Added lines #L173 - L174 were not covered by tests
end
new_div = SymbolicUtils.Div(new_num, new_den)
factor(new_div, eqs=eqs, split_div=split_div)
return eqs

Check warning on line 178 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L176-L178

Added lines #L176 - L178 were not covered by tests
end
new_div = SymbolicUtils.Div(new_num, new_den)
factor(new_div, eqs=eqs)
return eqs
elseif exprtype(ex)==POW
if base_term(ex.base)
new_base = ex.base
else
factor(ex.base, eqs=eqs)
factor(ex.base, eqs=eqs, split_div=split_div)

Check warning on line 184 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L184

Added line #L184 was not covered by tests
new_base = eqs[end].lhs
end
if base_term(ex.exp)
new_exp = ex.exp
else
factor(ex.exp, eqs=eqs)
factor(ex.exp, eqs=eqs, split_div=split_div)

Check warning on line 190 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L190

Added line #L190 was not covered by tests
new_exp = eqs[end].lhs
end
new_pow = SymbolicUtils.Pow(new_base, new_exp)
factor(new_pow, eqs=eqs)
factor(new_pow, eqs=eqs, split_div=split_div)

Check warning on line 194 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L194

Added line #L194 was not covered by tests
return eqs
elseif exprtype(ex)==TERM
new_args = []
for arg in ex.arguments
if base_term(arg)
push!(new_args, arg)
else
factor(arg, eqs=eqs)
factor(arg, eqs=eqs, split_div=split_div)

Check warning on line 202 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L202

Added line #L202 was not covered by tests
push!(new_args, eqs[end].lhs)
end
end
new_func = SymbolicUtils.Term(ex.f, new_args)
factor(new_func, eqs=eqs)
factor(new_func, eqs=eqs, split_div=split_div)

Check warning on line 207 in src/transform/factor.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/factor.jl#L207

Added line #L207 was not covered by tests
return eqs
end
return eqs
Expand Down
73 changes: 65 additions & 8 deletions src/transform/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,34 @@
function sort_vars(strings::Vector{String})
# isempty(strings) && return
sort_names = fill("", length(strings))

# Step 1) Check for derivative-type variables
# @show strings
# split_strings = string.(hcat(split.(strings, "_")...)[1,:])
split_strings = first.(split.(strings, "_"))
if strings == split_strings
# Simpler case; we can sort more-or-less normally

# Put constants first, if any exist
for i in eachindex(split_strings)
if split_strings[i]=="constant"
split_strings[i] = "_____constant"
end
end
return sortperm(split_strings)

# Sort split_strings, and if the strings follow the pattern [letters][numbers],
# sort by [letters] first and then by [numbers]. Otherwise, treat the string
# just as a normal string
return sortperm(split_strings, by = s -> begin
m = match(r"([a-zA-Z]+)(\d+)", s)
if m !== nothing
prefix = m.captures[1]
number = parse(Int, m.captures[2])
return (prefix, number)

Check warning on line 316 in src/transform/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/utilities.jl#L314-L316

Added lines #L314 - L316 were not covered by tests
else
return (s, 0)
end
end)
end
deriv = fill(false, length(strings))
# Here's a way to check for derivatives if we need to go back to "d" instead of '∂'
Expand Down Expand Up @@ -399,16 +415,16 @@

function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{String})
if exprtype(term)==SYM
if ~(string(term) in strings)
push!(strings, string(term))
if ~(string(get_name(term)) in strings)
push!(strings, string(get_name(term)))
push!(vars, term)
return vars, strings
end
return vars, strings
end
if exprtype(term)==TERM && varterm(term)
if ~(string(term.f) in strings) && (term.f==getindex && ~(string(term) in string.(vars)))
push!(strings, string(term))
push!(strings, string(get_name(term)))

Check warning on line 427 in src/transform/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/utilities.jl#L427

Added line #L427 was not covered by tests
push!(vars, term)
return vars, strings
end
Expand All @@ -421,8 +437,8 @@
end
~(typeof(arg)<:BasicSymbolic) ? continue : nothing
if exprtype(arg)==SYM
if ~(string(arg) in strings)
push!(strings, string(arg))
if ~(string(get_name(arg)) in strings)
push!(strings, string(get_name(arg)))
push!(vars, arg)
end
elseif typeof(arg) <: Real
Expand Down Expand Up @@ -450,8 +466,8 @@
```
eqs = [y ~ 15*x,
z ~ (1+y)^2]
shrink_eqs(eqs, 1)

julia> shrink_eqs(eqs, 1)
1-element Vector{Equation}:
z ~ (1 + 15x)^2
```
Expand All @@ -476,6 +492,47 @@
return new_eqs
end

"""
extract(::Vector{Equation})
extract(::Vector{Equation}, ::Int)

Given a set of symbolic equations, and optinally a specific
element that you would like expanded into a full expression,
progressively substitute the RHS definitions of LHS terms
until there is only that equation remaining (default = end).
Returns the RHS of that expression as an object of type
SymbolicUtils.BasicSymbolic{Real}.

# Example

```
eqs = [y ~ 15*x,
z ~ (1+y)^2]

julia> extract(eqs)
(1 + 15x)^2
```
"""
function extract(eqs::Vector{Equation}, ID::Int=length(eqs))
final_expr = eqs[ID].rhs
progress = true
while progress
progress = false
for var in pull_vars(final_expr)
eq_ID = findfirst(x -> isequal(x.lhs, var), eqs)
if !isnothing(eq_ID)
if isequal(eqs[eq_ID].lhs, eqs[eq_ID].rhs)
nothing

Check warning on line 525 in src/transform/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/utilities.jl#L516-L525

Added lines #L516 - L525 were not covered by tests
else
final_expr = substitute(final_expr, Dict(eqs[eq_ID].lhs => eqs[eq_ID].rhs))
progress = true

Check warning on line 528 in src/transform/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/utilities.jl#L527-L528

Added lines #L527 - L528 were not covered by tests
end
end
end
end
return final_expr

Check warning on line 533 in src/transform/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/utilities.jl#L531-L533

Added lines #L531 - L533 were not covered by tests
end

"""
convex_evaluator(::Num)
convex_evaluator(::Equation)
Expand Down
Loading
Loading