Skip to content

Commit fc753e0

Browse files
committed
WIP moshi
1 parent 4d86901 commit fc753e0

15 files changed

+437
-199
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
1616
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1717
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1818
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
1920
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
2021
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2122
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2526
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2627
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2728
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
28-
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
2929
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
3030

3131
[weakdeps]
@@ -48,6 +48,7 @@ DocStringExtensions = "0.8, 0.9"
4848
DynamicPolynomials = "0.5, 0.6"
4949
IfElse = "0.1"
5050
LabelledArrays = "1.5"
51+
Moshi = "0.3.5"
5152
MultivariatePolynomials = "0.5"
5253
NaNMath = "0.3, 1"
5354
ReverseDiff = "1"
@@ -57,7 +58,6 @@ StaticArrays = "0.12, 1.0"
5758
SymbolicIndexingInterface = "0.3"
5859
TermInterface = "2.0"
5960
TimerOutputs = "0.5"
60-
Unityper = "0.1.2"
6161
WeakValueDicts = "0.1.0"
6262
julia = "1.3"
6363

src/SymbolicUtils.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using DocStringExtensions
77

88
export @syms, term, showraw, hasmetadata, getmetadata, setmetadata
99

10-
using Unityper
10+
using Moshi.Data: @data, data_type_name, variant_name
11+
using Moshi.Match: @match
1112
using TermInterface
1213
using DataStructures
1314
using Setfield

src/code.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
88

99
import ..SymbolicUtils
1010
import ..SymbolicUtils.Rewriters
11-
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
12-
symtype, sorted_arguments, metadata, isterm, term, maketerm
11+
import SymbolicUtils: @matchable, BasicSymbolicType, Sym, Term, iscall, operation, arguments, issym,
12+
isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm
1313
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1414

1515
##== state management ==##
@@ -156,7 +156,7 @@ function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
156156
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
157157
end
158158

159-
function function_to_expr(x::BasicSymbolic, O, st)
159+
function function_to_expr(x::BasicSymbolicType, O, st)
160160
issym(x) ? get(st.rewrites, O, nothing) : nothing
161161
end
162162

@@ -182,6 +182,8 @@ function toexpr(O, st)
182182
if issym(O)
183183
O = substitute_name(O, st)
184184
return issym(O) ? nameof(O) : toexpr(O, st)
185+
elseif isconst(O)
186+
return toexpr(O.val, st)
185187
end
186188
O = substitute_name(O, st)
187189

@@ -766,7 +768,7 @@ end
766768
function cse_block(state, t, name=Symbol("var-", hash(t)))
767769
assignments = Assignment[]
768770
counter = Ref{Int}(1)
769-
names = Dict{Any, BasicSymbolic}()
771+
names = Dict{Any, BasicSymbolicType}()
770772
Let(assignments, cse_block!(assignments, counter, names, name, state, t))
771773
end
772774

src/inspect.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function AbstractTrees.nodevalue(x::Symbolic)
55
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
66
end
77

8-
function AbstractTrees.nodevalue(x::BasicSymbolic)
8+
function AbstractTrees.nodevalue(x::BasicSymbolicType)
99
str = if !iscall(x)
1010
string(exprtype(x), "(", x, ")")
1111
elseif isadd(x)

src/matchers.jl

+16-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,23 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88
function matcher(val::Any)
9-
iscall(val) && return term_matcher(val)
9+
if isconst(val)
10+
slot = val.val
11+
return matcher(slot)
12+
elseif iscall(val)
13+
return term_matcher(val)
14+
end
1015
function literal_matcher(next, data, bindings)
11-
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
16+
if islist(data)
17+
cd = car(data)
18+
if isconst(cd)
19+
cd = cd.val
20+
end
21+
if isequal(cd, val)
22+
return next(bindings, 1)
23+
end
24+
end
25+
nothing
1226
end
1327
end
1428

src/methods.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ macro number_methods(T, rhs1, rhs2, options=nothing)
9595
number_methods(T, rhs1, rhs2, options) |> esc
9696
end
9797

98-
@number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics)
99-
@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
98+
@number_methods(BasicSymbolicType{<:Number}, term(f, a), term(f, a, b), skipbasics)
99+
@number_methods(BasicSymbolicType{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
100100

101101
for f in vcat(diadic, [+, -, *, \, /, ^])
102102
@eval promote_symtype(::$(typeof(f)),
@@ -188,7 +188,7 @@ end
188188
for f in [!, ~]
189189
@eval begin
190190
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
191-
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
191+
(::$(typeof(f)))(s::Symbolic{Bool}) = isconst(s) ? !s.val : Term{Bool}(!, [s])
192192
end
193193
end
194194

src/ordering.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function get_degrees(expr)
2727
elseif iscall(expr)
2828
op = operation(expr)
2929
args = sorted_arguments(expr)
30-
if op == (^) && args[2] isa Number
30+
if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].val isa Number))
3131
return map(get_degrees(args[1])) do (base, pow)
3232
(base => pow * args[2])
3333
end
@@ -78,13 +78,16 @@ function <ₑ(a::Tuple, b::Tuple)
7878
return length(a) < length(b)
7979
end
8080

81-
function <(a::BasicSymbolic, b::BasicSymbolic)
81+
function <(a::BasicSymbolicType, b::BasicSymbolicType)
82+
isconst(a) && isconst(b) && return a.val <ₑ b.val
83+
isconst(a) && return a.val <ₑ b
84+
isconst(b) && return a <ₑ b.val
8285
da, db = get_degrees(a), get_degrees(b)
8386
fw = monomial_lt(da, db)
8487
bw = monomial_lt(db, da)
8588
if fw === bw && !isequal(a, b)
8689
if _arglen(a) == _arglen(b)
87-
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
90+
return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...)
8891
else
8992
return _arglen(a) < _arglen(b)
9093
end

src/polyform.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ PolyForm(sin((x+y)^2), recurse=true) #=> sin((x^2 + (2x)y + y^2))
2929
struct PolyForm{T} <: Symbolic{T}
3030
p::MP.AbstractPolynomialLike
3131
pvar2sym::Bijection{Any,Any} # @polyvar x --> @sym x etc.
32-
sym2term::Dict{BasicSymbolic,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...))
32+
sym2term::Dict{BasicSymbolicType,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...))
3333
metadata
3434
function (::Type{PolyForm{T}})(p, d1, d2, m=nothing) where {T}
3535
p isa Number && return p
@@ -63,7 +63,7 @@ end
6363
function get_sym2term()
6464
v = SYM2TERM[].value
6565
if v === nothing
66-
d = Dict{BasicSymbolic,Any}()
66+
d = Dict{BasicSymbolicType,Any}()
6767
SYM2TERM[] = WeakRef(d)
6868
return d
6969
else
@@ -95,6 +95,7 @@ end
9595
_isone(p::PolyForm) = isone(p.p)
9696

9797
function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
98+
x = isconst(x) ? x.val : x
9899
if x isa Number
99100
return x
100101
elseif iscall(x)

src/substitute.jl

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ function substitute(expr, dict; fold=true)
2222
canfold = !(op isa Symbolic)
2323
args = map(arguments(expr)) do x
2424
x′ = substitute(x, dict; fold=fold)
25+
x′ = isconst(x) ? x′.val : x′
2526
canfold = canfold && !(x′ isa Symbolic)
2627
x′
2728
end
@@ -54,6 +55,7 @@ function _occursin(needle, haystack)
5455
if iscall(haystack)
5556
args = arguments(haystack)
5657
for arg in args
58+
arg = isconst(arg) ? arg.val : arg
5759
if needle isa Integer || needle isa AbstractFloat
5860
isequal(needle, arg) && return true
5961
else

0 commit comments

Comments
 (0)