Skip to content

Commit 3d15b12

Browse files
committed
feat: get parse_expressions to work with aliases
1 parent 2c8d3ca commit 3d15b12

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "2.3.0"
4+
version = "2.4.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Parse.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ..ExpressionModule:
1212
get_operators,
1313
get_variable_names,
1414
node_type
15+
using ..ExpressionAlgebraModule: declare_operator_alias
1516

1617
"""
1718
@parse_expression(expr; operators, variable_names, node_type=Node, evaluate_on=[])
@@ -331,8 +332,13 @@ end
331332
kws...,
332333
)::N where {F<:Function,N<:AbstractExpressionNode,E<:AbstractExpression}
333334
degree = length(args) - 1
334-
if degree <= length(operators.ops) && func operators[degree]
335-
op_idx = findfirst(==(func), operators[degree])
335+
if degree <= length(operators.ops) && (
336+
op_idx = findfirst(
337+
op -> op == func || declare_operator_alias(op, Val(degree)) == func,
338+
operators[degree],
339+
);
340+
!isnothing(op_idx)
341+
)
336342
return N(;
337343
op=op_idx::Int,
338344
children=map(
@@ -342,8 +348,18 @@ end
342348
(args[2:end]...,),
343349
),
344350
)
345-
elseif degree > 2 && func (+, -, *) && func operators[2]
346-
op_idx = findfirst(==(func), operators[2])::Int
351+
end
352+
353+
# Handle chaining for +, -, * operators
354+
if degree > 2 &&
355+
func (+, -, *) &&
356+
(
357+
op_idx = findfirst(
358+
op -> op == func || declare_operator_alias(op, Val(2)) == func,
359+
operators[degree],
360+
);
361+
!isnothing(op_idx)
362+
)
347363
inner = N(;
348364
op=op_idx::Int,
349365
children=(

test/test_parse.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,40 @@
4242
end
4343
end
4444

45+
@testitem "Parse with operator aliases" begin
46+
using DynamicExpressions
47+
using DynamicExpressions: DynamicExpressions as DE
48+
using Test
49+
50+
## UNARY
51+
safe_sqrt(x) = x < 0 ? convert(typeof(x), NaN) : sqrt(x)
52+
DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt
53+
54+
operators = OperatorEnum(1 => [safe_sqrt, sin, cos], 2 => [+, -, *, /])
55+
56+
ex = parse_expression(
57+
"sqrt(x) + sin(y)"; operators=operators, variable_names=["x", "y"]
58+
)
59+
60+
@test typeof(ex) <: Expression
61+
@test ex.tree.op == 1
62+
@test ex.tree.children[1].x.op == 1
63+
@test ex.tree.children[2].x.op == 2
64+
65+
## BINARY
66+
safe_pow(x, y) = x < 0 && y != round(y) ? NaN : x^y
67+
DE.declare_operator_alias(::typeof(safe_pow), ::Val{2}) = ^
68+
69+
operators = OperatorEnum(1 => [sin], 2 => [+, -, safe_pow, *])
70+
ex = parse_expression("x^2 + sin(y)"; operators=operators, variable_names=["x", "y"])
71+
72+
@test typeof(ex) <: Expression
73+
@test ex.tree.op == 1
74+
@test ex.tree.children[1].x.op == 3 # safe_pow
75+
@test ex.tree.children[1].x.children[2].x.val == 2.0
76+
@test ex.tree.children[2].x.op == 1
77+
end
78+
4579
@testitem "Can also parse just a float" begin
4680
using DynamicExpressions
4781
operators = OperatorEnum() # Tests empty operators

0 commit comments

Comments
 (0)