Skip to content

Commit a189a89

Browse files
authored
Add ScalarNonlinearFunction (#2059)
1 parent c18c8b2 commit a189a89

31 files changed

+1480
-17
lines changed

docs/src/manual/standard_form.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ The function types implemented in MathOptInterface.jl are:
3737
| [`VariableIndex`](@ref) | ``x_j``, the projection onto a single coordinate defined by a variable index ``j``. |
3838
| [`VectorOfVariables`](@ref) | The projection onto multiple coordinates (that is, extracting a sub-vector). |
3939
| [`ScalarAffineFunction`](@ref) | ``a^T x + b``, where ``a`` is a vector and ``b`` scalar. |
40+
| [`ScalarNonlinearFunction`](@ref) | ``f(x)``, where ``f`` is a nonlinear function. |
4041
| [`VectorAffineFunction`](@ref) | ``A x + b``, where ``A`` is a matrix and ``b`` is a vector. |
4142
| [`ScalarQuadraticFunction`](@ref) | ``\frac{1}{2} x^T Q x + a^T x + b``, where ``Q`` is a symmetric matrix, ``a`` is a vector, and ``b`` is a constant. |
4243
| [`VectorQuadraticFunction`](@ref) | A vector of scalar-valued quadratic functions. |

docs/src/reference/errors.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ ModifyObjectiveNotAllowed
7575
DeleteNotAllowed
7676
UnsupportedSubmittable
7777
SubmitNotAllowed
78+
UnsupportedNonlinearOperator
7879
```
7980

8081
Note that setting the [`ConstraintFunction`](@ref) of a [`VariableIndex`](@ref)

docs/src/reference/models.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ ListOfOptimizerAttributesSet
5050
ListOfModelAttributesSet
5151
ListOfVariableAttributesSet
5252
ListOfConstraintAttributesSet
53+
UserDefinedFunction
54+
ListOfSupportedNonlinearOperators
5355
```
5456

5557
## Optimizer interface

docs/src/reference/standard_form.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ ScalarAffineTerm
2525
ScalarAffineFunction
2626
ScalarQuadraticTerm
2727
ScalarQuadraticFunction
28+
ScalarNonlinearFunction
2829
```
2930

3031
## Vector functions

src/Bridges/Objective/bridges/slack.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function bridge_objective(
7070
end
7171
constraint = MOI.Utilities.normalize_and_add_constraint(model, f, set)
7272
MOI.set(model, MOI.ObjectiveFunction{MOI.VariableIndex}(), slack)
73-
return SlackBridge{T,F,G}(slack, constraint, MOI.constant(f))
73+
return SlackBridge{T,F,G}(slack, constraint, MOI.constant(f, T))
7474
end
7575

7676
function supports_objective_function(
@@ -166,7 +166,11 @@ function MOI.get(
166166
bridge::SlackBridge{T,F,G},
167167
) where {T,F,G<:MOI.AbstractScalarFunction}
168168
func = MOI.get(model, MOI.ConstraintFunction(), bridge.constraint)
169-
f = MOI.Utilities.operate(+, T, func, bridge.constant)
169+
f = if !iszero(bridge.constant)
170+
MOI.Utilities.operate(+, T, func, bridge.constant)
171+
else
172+
func
173+
end
170174
g = MOI.Utilities.remove_variable(f, bridge.slack)
171175
return MOI.Utilities.convert_approx(G, g)
172176
end

src/Nonlinear/model.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,7 @@ function evaluate(
327327
end
328328
return storage[1]
329329
end
330+
331+
function MOI.get(model::Model, attr::MOI.ListOfSupportedNonlinearOperators)
332+
return MOI.get(model.operators, attr)
333+
end

src/Nonlinear/operators.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,60 @@ end
7474
DEFAULT_UNIVARIATE_OPERATORS
7575
7676
The list of univariate operators that are supported by default.
77+
78+
## Example
79+
80+
```jldoctest
81+
julia> import MathOptInterface as MOI
82+
83+
julia> MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS
84+
72-element Vector{Symbol}:
85+
:+
86+
:-
87+
:abs
88+
:sqrt
89+
:cbrt
90+
:abs2
91+
:inv
92+
:log
93+
:log10
94+
:log2
95+
96+
:airybi
97+
:airyaiprime
98+
:airybiprime
99+
:besselj0
100+
:besselj1
101+
:bessely0
102+
:bessely1
103+
:erfcx
104+
:dawson
105+
```
77106
"""
78107
const DEFAULT_UNIVARIATE_OPERATORS = first.(SYMBOLIC_UNIVARIATE_EXPRESSIONS)
79108

80109
"""
81110
DEFAULT_MULTIVARIATE_OPERATORS
82111
83112
The list of multivariate operators that are supported by default.
113+
114+
## Example
115+
116+
```jldoctest
117+
julia> import MathOptInterface as MOI
118+
119+
julia> MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS
120+
9-element Vector{Symbol}:
121+
:+
122+
:-
123+
:*
124+
:^
125+
:/
126+
:ifelse
127+
:atan
128+
:min
129+
:max
130+
```
84131
"""
85132
const DEFAULT_MULTIVARIATE_OPERATORS =
86133
[:+, :-, :*, :^, :/, :ifelse, :atan, :min, :max]
@@ -140,6 +187,19 @@ struct OperatorRegistry
140187
end
141188
end
142189

190+
function MOI.get(
191+
registry::OperatorRegistry,
192+
::MOI.ListOfSupportedNonlinearOperators,
193+
)
194+
ops = vcat(
195+
registry.univariate_operators,
196+
registry.multivariate_operators,
197+
registry.logic_operators,
198+
registry.comparison_operators,
199+
)
200+
return unique(ops)
201+
end
202+
143203
const _FORWARD_DIFF_METHOD_ERROR_HELPER = raw"""
144204
Common reasons for this include:
145205

src/Nonlinear/parse.jl

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,58 @@ function parse_expression(::Model, ::Expression, x::Any, ::Int)
3636
)
3737
end
3838

39+
function parse_expression(
40+
data::Model,
41+
expr::Expression,
42+
x::MOI.ScalarNonlinearFunction,
43+
parent_index::Int,
44+
)
45+
stack = Tuple{Int,Any}[(parent_index, x)]
46+
while !isempty(stack)
47+
parent_node, arg = pop!(stack)
48+
if arg isa MOI.ScalarNonlinearFunction
49+
_parse_without_recursion_inner(stack, data, expr, arg, parent_node)
50+
else
51+
# We can use recursion here, because ScalarNonlinearFunction only
52+
# occur in other ScalarNonlinearFunction.
53+
parse_expression(data, expr, arg, parent_node)
54+
end
55+
end
56+
return
57+
end
58+
59+
function _get_node_type(data, x)
60+
id = get(data.operators.univariate_operator_to_id, x.head, nothing)
61+
if length(x.args) == 1 && id !== nothing
62+
return id, MOI.Nonlinear.NODE_CALL_UNIVARIATE
63+
end
64+
id = get(data.operators.multivariate_operator_to_id, x.head, nothing)
65+
if id !== nothing
66+
return id, MOI.Nonlinear.NODE_CALL_MULTIVARIATE
67+
end
68+
id = get(data.operators.comparison_operator_to_id, x.head, nothing)
69+
if id !== nothing
70+
return id, MOI.Nonlinear.NODE_COMPARISON
71+
end
72+
id = get(data.operators.logic_operator_to_id, x.head, nothing)
73+
if id !== nothing
74+
return id, MOI.Nonlinear.NODE_LOGIC
75+
end
76+
return throw(MOI.UnsupportedNonlinearOperator(x.head))
77+
end
78+
79+
function _parse_without_recursion_inner(stack, data, expr, x, parent)
80+
id, node_type = _get_node_type(data, x)
81+
push!(expr.nodes, Node(node_type, id, parent))
82+
parent = length(expr.nodes)
83+
# Args need to be pushed onto the stack in reverse because the stack is a
84+
# first-in last-out datastructure.
85+
for arg in reverse(x.args)
86+
push!(stack, (parent, arg))
87+
end
88+
return
89+
end
90+
3991
function parse_expression(
4092
data::Model,
4193
expr::Expression,
@@ -108,7 +160,7 @@ function _parse_univariate_expression(
108160
_parse_multivariate_expression(stack, data, expr, x, parent_index)
109161
return
110162
end
111-
error("Unable to parse: $x")
163+
throw(MOI.UnsupportedNonlinearOperator(x.args[1]))
112164
end
113165
push!(expr.nodes, Node(NODE_CALL_UNIVARIATE, id, parent_index))
114166
push!(stack, (length(expr.nodes), x.args[2]))
@@ -200,6 +252,28 @@ function parse_expression(
200252
return
201253
end
202254

255+
function parse_expression(
256+
data::Model,
257+
expr::Expression,
258+
x::MOI.ScalarAffineFunction,
259+
parent_index::Int,
260+
)
261+
f = convert(MOI.ScalarNonlinearFunction, x)
262+
parse_expression(data, expr, f, parent_index)
263+
return
264+
end
265+
266+
function parse_expression(
267+
data::Model,
268+
expr::Expression,
269+
x::MOI.ScalarQuadraticFunction,
270+
parent_index::Int,
271+
)
272+
f = convert(MOI.ScalarNonlinearFunction, x)
273+
parse_expression(data, expr, f, parent_index)
274+
return
275+
end
276+
203277
function parse_expression(::Model, expr::Expression, x::Real, parent_index::Int)
204278
push!(expr.values, convert(Float64, x)::Float64)
205279
push!(expr.nodes, Node(NODE_VALUE, length(expr.values), parent_index))

src/Test/test_basic_constraint.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ function _function(
6666
)
6767
end
6868

69+
function _function(
70+
::Type{T},
71+
::Type{MOI.ScalarNonlinearFunction},
72+
x::Vector{MOI.VariableIndex},
73+
) where {T}
74+
return MOI.ScalarNonlinearFunction(
75+
:+,
76+
Any[MOI.ScalarNonlinearFunction(:^, Any[xi, 2]) for xi in x],
77+
)
78+
end
79+
6980
# Default fallback.
7081
_set(::Any, ::Type{S}) where {S} = _set(S)
7182

@@ -316,7 +327,12 @@ for s in [
316327
]
317328
S = getfield(MOI, s)
318329
functions = if S <: MOI.AbstractScalarSet
319-
(:VariableIndex, :ScalarAffineFunction, :ScalarQuadraticFunction)
330+
(
331+
:VariableIndex,
332+
:ScalarAffineFunction,
333+
:ScalarQuadraticFunction,
334+
:ScalarNonlinearFunction,
335+
)
320336
else
321337
(:VectorOfVariables, :VectorAffineFunction, :VectorQuadraticFunction)
322338
end

0 commit comments

Comments
 (0)