From e0196cfdaf6f5a1a6cf8136345d013a5903df399 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 Jan 2025 20:44:44 +0530 Subject: [PATCH 1/2] fix: fix `fast_substitute` on subarrays of symbolic variables --- src/variable.jl | 2 ++ test/utils.jl | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/src/variable.jl b/src/variable.jl index 37ba40271..9d9cc260a 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -606,6 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing) canfold = Ref(!(op isa Symbolic)) args = let canfold = canfold map(args) do x + symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x x′ = fast_substitute(x, subs; operator) canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ @@ -633,6 +634,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) canfold = Ref(!(op isa Symbolic)) args = let canfold = canfold map(args) do x + symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x x′ = fast_substitute(x, pair; operator) canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ diff --git a/test/utils.jl b/test/utils.jl index 8203b41f6..a4f3db467 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -162,3 +162,11 @@ end ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0)) @test isequal(ex2, foo([x, 1.0], 2.0)) end + +@testset "`fast_substitute` of subarray symbolics" begin + @variables p[1:4] q[1:5] + @test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], Dict())) + @test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], p => p)) + @test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], Dict(p => q))) + @test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], p => q)) +end From 16df6db70e586f5d2bee87dfc487a532c2216d15 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 31 Jan 2025 12:58:41 +0530 Subject: [PATCH 2/2] fix: fix `build_function` for array symbolics --- src/build_function.jl | 15 ++++++++------- test/build_function.jl | 11 +++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 34d53481e..0a6c152dc 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -132,7 +132,7 @@ end SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x)) -function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; +function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...; conv = toexpr, expression = Val{true}, expression_module = @__MODULE__(), @@ -141,6 +141,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; linenumbers = true, cse = false, nanmath = true, + wrap_code = (identity, identity), kwargs...) dargs = map((x) -> destructure_arg(x[2], !checkbounds, @@ -155,7 +156,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; outsym = Symbol("ˍ₋out") body = inplace_expr(unwrap(op), outsym) - oop_expr = conv(Func([outsym, dargs...], [], body), states) + iip_expr = conv(wrap_code[2](Func([outsym, dargs...], [], body)), states) N = length(shape(op)) op = unwrap(op) @@ -167,18 +168,18 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; $outsym end) |> LiteralExpr end - ip_expr = conv(Func(dargs, [], op_body), states) + oop_expr = conv(wrap_code[1](Func(dargs, [], op_body)), states) if !checkbounds @assert Meta.isexpr(oop_expr, :function) oop_expr.args[2] = :(@inbounds begin; $(oop_expr.args[2]); end) - @assert Meta.isexpr(ip_expr, :function) - ip_expr.args[2] = :(@inbounds begin; $(ip_expr.args[2]); end) + @assert Meta.isexpr(iip_expr, :function) + iip_expr.args[2] = :(@inbounds begin; $(iip_expr.args[2]); end) end if expression == Val{true} - oop_expr, ip_expr + oop_expr, iip_expr else _build_and_inject_function(expression_module, oop_expr), - _build_and_inject_function(expression_module, ip_expr) + _build_and_inject_function(expression_module, iip_expr) end end diff --git a/test/build_function.jl b/test/build_function.jl index 0c09fdd6d..f7df7b763 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -290,3 +290,14 @@ end fn = build_function(T, collect(x); similarto = Array, expression = false)[1] @test fn((1.0, 2.0)) ≈ [1.0, 4.0] end + +@testset "`build_function` with array symbolics" begin + @variables x[1:4] + for var in [x[1:2], x[1:2] .+ 0.0, Symbolics.unwrap(x[1:2])] + foop, fiip = build_function(var[1:2], x; expression = false) + @test foop(ones(4)) ≈ ones(2) + buf = zeros(2) + fiip(buf, ones(4)) + @test buf ≈ ones(2) + end +end