Skip to content

Commit

Permalink
Merge pull request #1423 from AayushSabharwal/as/fix-fastsub
Browse files Browse the repository at this point in the history
fix: fix `fast_substitute` on subarrays of symbolic variables
  • Loading branch information
ChrisRackauckas authored Jan 31, 2025
2 parents ec3bd4c + 16df6db commit 024e5a5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ end

SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x))

function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...;
conv = toexpr,
expression = Val{true},
expression_module = @__MODULE__(),
Expand All @@ -141,6 +141,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
linenumbers = true,
cse = false,
nanmath = true,
wrap_code = (identity, identity),
kwargs...)

dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Expand All @@ -155,7 +156,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;

outsym = Symbol("ˍ₋out")
body = inplace_expr(unwrap(op), outsym)
oop_expr = conv(Func([outsym, dargs...], [], body), states)
iip_expr = conv(wrap_code[2](Func([outsym, dargs...], [], body)), states)

N = length(shape(op))
op = unwrap(op)
Expand All @@ -167,18 +168,18 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
$outsym
end) |> LiteralExpr
end
ip_expr = conv(Func(dargs, [], op_body), states)
oop_expr = conv(wrap_code[1](Func(dargs, [], op_body)), states)
if !checkbounds
@assert Meta.isexpr(oop_expr, :function)
oop_expr.args[2] = :(@inbounds begin; $(oop_expr.args[2]); end)
@assert Meta.isexpr(ip_expr, :function)
ip_expr.args[2] = :(@inbounds begin; $(ip_expr.args[2]); end)
@assert Meta.isexpr(iip_expr, :function)
iip_expr.args[2] = :(@inbounds begin; $(iip_expr.args[2]); end)
end
if expression == Val{true}
oop_expr, ip_expr
oop_expr, iip_expr
else
_build_and_inject_function(expression_module, oop_expr),
_build_and_inject_function(expression_module, ip_expr)
_build_and_inject_function(expression_module, iip_expr)
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x
x′ = fast_substitute(x, subs; operator)
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
x′
Expand Down Expand Up @@ -633,6 +634,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x
x′ = fast_substitute(x, pair; operator)
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
x′
Expand Down
11 changes: 11 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,14 @@ end
fn = build_function(T, collect(x); similarto = Array, expression = false)[1]
@test fn((1.0, 2.0)) [1.0, 4.0]
end

@testset "`build_function` with array symbolics" begin
@variables x[1:4]
for var in [x[1:2], x[1:2] .+ 0.0, Symbolics.unwrap(x[1:2])]
foop, fiip = build_function(var[1:2], x; expression = false)
@test foop(ones(4)) ones(2)
buf = zeros(2)
fiip(buf, ones(4))
@test buf ones(2)
end
end
8 changes: 8 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,11 @@ end
ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0))
@test isequal(ex2, foo([x, 1.0], 2.0))
end

@testset "`fast_substitute` of subarray symbolics" begin
@variables p[1:4] q[1:5]
@test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], Dict()))
@test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], p => p))
@test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], Dict(p => q)))
@test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], p => q))
end

0 comments on commit 024e5a5

Please sign in to comment.