Skip to content

Commit

Permalink
fix: fix build_function for array symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 31, 2025
1 parent e0196cf commit 16df6db
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ end

SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x))

function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...;
conv = toexpr,
expression = Val{true},
expression_module = @__MODULE__(),
Expand All @@ -141,6 +141,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
linenumbers = true,
cse = false,
nanmath = true,
wrap_code = (identity, identity),
kwargs...)

dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Expand All @@ -155,7 +156,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;

outsym = Symbol("ˍ₋out")
body = inplace_expr(unwrap(op), outsym)
oop_expr = conv(Func([outsym, dargs...], [], body), states)
iip_expr = conv(wrap_code[2](Func([outsym, dargs...], [], body)), states)

N = length(shape(op))
op = unwrap(op)
Expand All @@ -167,18 +168,18 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
$outsym
end) |> LiteralExpr
end
ip_expr = conv(Func(dargs, [], op_body), states)
oop_expr = conv(wrap_code[1](Func(dargs, [], op_body)), states)
if !checkbounds
@assert Meta.isexpr(oop_expr, :function)
oop_expr.args[2] = :(@inbounds begin; $(oop_expr.args[2]); end)
@assert Meta.isexpr(ip_expr, :function)
ip_expr.args[2] = :(@inbounds begin; $(ip_expr.args[2]); end)
@assert Meta.isexpr(iip_expr, :function)
iip_expr.args[2] = :(@inbounds begin; $(iip_expr.args[2]); end)
end
if expression == Val{true}
oop_expr, ip_expr
oop_expr, iip_expr
else
_build_and_inject_function(expression_module, oop_expr),
_build_and_inject_function(expression_module, ip_expr)
_build_and_inject_function(expression_module, iip_expr)
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,14 @@ end
fn = build_function(T, collect(x); similarto = Array, expression = false)[1]
@test fn((1.0, 2.0)) [1.0, 4.0]
end

@testset "`build_function` with array symbolics" begin
@variables x[1:4]
for var in [x[1:2], x[1:2] .+ 0.0, Symbolics.unwrap(x[1:2])]
foop, fiip = build_function(var[1:2], x; expression = false)
@test foop(ones(4)) ones(2)
buf = zeros(2)
fiip(buf, ones(4))
@test buf ones(2)
end
end

0 comments on commit 16df6db

Please sign in to comment.