Skip to content

Commit 7017514

Browse files
Remove eval and rework primitive extraction
1 parent 05fa4b5 commit 7017514

21 files changed

+668
-409
lines changed

docs/changelog.md

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ _This project uses semantic versioning_
2424
- Updates function constructor to remove `default` and `on_merge`. You also can't set a `cost` when you use a `merge`
2525
function or return a primitive.
2626
- `eq` now only takes two args, instead of being able to compare any number of values.
27+
- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives.
28+
- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`.
29+
- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions.
2730

2831
## 8.0.1 (2024-10-24)
2932

python/egglog/bindings.pyi

+3-6
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,14 @@ class SerializedEGraph:
9292
class PyObjectSort:
9393
def __init__(self) -> None: ...
9494
def store(self, __o: object, /) -> _Expr: ...
95+
def load(self, __e: _Expr, /) -> object: ...
9596

9697
@final
9798
class EGraph:
9899
def __init__(
99100
self,
100-
__py_object_sort: PyObjectSort | None = None,
101+
py_object_sort: PyObjectSort | None = None,
102+
/,
101103
*,
102104
fact_directory: str | Path | None = None,
103105
seminaive: bool = True,
@@ -116,11 +118,6 @@ class EGraph:
116118
max_calls_per_function: int | None = None,
117119
include_temporary_functions: bool = False,
118120
) -> SerializedEGraph: ...
119-
def eval_py_object(self, __expr: _Expr) -> object: ...
120-
def eval_i64(self, __expr: _Expr) -> int: ...
121-
def eval_f64(self, __expr: _Expr) -> float: ...
122-
def eval_string(self, __expr: _Expr) -> str: ...
123-
def eval_bool(self, __expr: _Expr) -> bool: ...
124121

125122
@final
126123
class EggSmolError(Exception):

python/egglog/builtins.py

+182-4
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,24 @@
55

66
from __future__ import annotations
77

8+
from fractions import Fraction
89
from functools import partial, reduce
9-
from types import FunctionType
10+
from types import FunctionType, MethodType
1011
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
1112

1213
from typing_extensions import TypeVarTuple, Unpack
1314

15+
from . import bindings
1416
from .conversion import convert, converter, get_type_args
15-
from .egraph import BaseExpr, BuiltinExpr, Unit, function, get_current_ruleset, method
17+
from .declarations import *
18+
from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19+
from .egraph_state import GLOBAL_PY_OBJECT_SORT
1620
from .functionalize import functionalize
1721
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
1822
from .thunk import Thunk
1923

2024
if TYPE_CHECKING:
21-
from collections.abc import Callable
25+
from collections.abc import Callable, Iterator
2226

2327

2428
__all__ = [
@@ -32,6 +36,7 @@
3236
"SetLike",
3337
"String",
3438
"StringLike",
39+
"Unit",
3540
"UnstableFn",
3641
"Vec",
3742
"VecLike",
@@ -46,7 +51,25 @@
4651
]
4752

4853

54+
class Unit(BuiltinExpr, egg_sort="Unit"):
55+
"""
56+
The unit type. This is used to reprsent if a value exists in the e-graph or not.
57+
"""
58+
59+
def __init__(self) -> None: ...
60+
61+
@method(preserve=True)
62+
def __bool__(self) -> bool:
63+
return bool(expr_fact(self))
64+
65+
4966
class String(BuiltinExpr):
67+
@method(preserve=True)
68+
def eval(self) -> str:
69+
value = _extract_lit(self)
70+
assert isinstance(value, bindings.String)
71+
return value.value
72+
5073
def __init__(self, value: str) -> None: ...
5174

5275
@method(egg_fn="replace")
@@ -62,10 +85,20 @@ def join(*strings: StringLike) -> String: ...
6285

6386
converter(str, String, String)
6487

65-
BoolLike = Union["Bool", bool]
88+
BoolLike: TypeAlias = Union["Bool", bool]
6689

6790

6891
class Bool(BuiltinExpr, egg_sort="bool"):
92+
@method(preserve=True)
93+
def eval(self) -> bool:
94+
value = _extract_lit(self)
95+
assert isinstance(value, bindings.Bool)
96+
return value.value
97+
98+
@method(preserve=True)
99+
def __bool__(self) -> bool:
100+
return self.eval()
101+
69102
def __init__(self, value: bool) -> None: ...
70103

71104
@method(egg_fn="not")
@@ -91,6 +124,20 @@ def implies(self, other: BoolLike) -> Bool: ...
91124

92125

93126
class i64(BuiltinExpr): # noqa: N801
127+
@method(preserve=True)
128+
def eval(self) -> int:
129+
value = _extract_lit(self)
130+
assert isinstance(value, bindings.Int)
131+
return value.value
132+
133+
@method(preserve=True)
134+
def __index__(self) -> int:
135+
return self.eval()
136+
137+
@method(preserve=True)
138+
def __int__(self) -> int:
139+
return self.eval()
140+
94141
def __init__(self, value: int) -> None: ...
95142

96143
@method(egg_fn="+")
@@ -193,6 +240,20 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
193240

194241

195242
class f64(BuiltinExpr): # noqa: N801
243+
@method(preserve=True)
244+
def eval(self) -> float:
245+
value = _extract_lit(self)
246+
assert isinstance(value, bindings.Float)
247+
return value.value
248+
249+
@method(preserve=True)
250+
def __float__(self) -> float:
251+
return self.eval()
252+
253+
@method(preserve=True)
254+
def __int__(self) -> int:
255+
return int(self.eval())
256+
196257
def __init__(self, value: float) -> None: ...
197258

198259
@method(egg_fn="neg")
@@ -265,6 +326,33 @@ def to_string(self) -> String: ...
265326

266327

267328
class Map(BuiltinExpr, Generic[T, V]):
329+
@method(preserve=True)
330+
def eval(self) -> dict[T, V]:
331+
call = _extract_call(self)
332+
expr = cast(RuntimeExpr, self)
333+
d = {}
334+
while call.callable != ClassMethodRef("Map", "empty"):
335+
assert call.callable == MethodRef("Map", "insert")
336+
call_typed, k_typed, v_typed = call.args
337+
assert isinstance(call_typed.expr, CallDecl)
338+
k = cast(T, expr.__with_expr__(k_typed))
339+
v = cast(V, expr.__with_expr__(v_typed))
340+
d[k] = v
341+
call = call_typed.expr
342+
return d
343+
344+
@method(preserve=True)
345+
def __iter__(self) -> Iterator[T]:
346+
return iter(self.eval())
347+
348+
@method(preserve=True)
349+
def __len__(self) -> int:
350+
return len(self.eval())
351+
352+
@method(preserve=True)
353+
def __contains__(self, key: T) -> bool:
354+
return key in self.eval()
355+
268356
@method(egg_fn="map-empty")
269357
@classmethod
270358
def empty(cls) -> Map[T, V]: ...
@@ -305,6 +393,24 @@ def rebuild(self) -> Map[T, V]: ...
305393

306394

307395
class Set(BuiltinExpr, Generic[T]):
396+
@method(preserve=True)
397+
def eval(self) -> set[T]:
398+
call = _extract_call(self)
399+
assert call.callable == InitRef("Set")
400+
return {cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args}
401+
402+
@method(preserve=True)
403+
def __iter__(self) -> Iterator[T]:
404+
return iter(self.eval())
405+
406+
@method(preserve=True)
407+
def __len__(self) -> int:
408+
return len(self.eval())
409+
410+
@method(preserve=True)
411+
def __contains__(self, key: T) -> bool:
412+
return key in self.eval()
413+
308414
@method(egg_fn="set-of")
309415
def __init__(self, *args: T) -> None: ...
310416

@@ -349,6 +455,28 @@ def rebuild(self) -> Set[T]: ...
349455

350456

351457
class Rational(BuiltinExpr):
458+
@method(preserve=True)
459+
def eval(self) -> Fraction:
460+
call = _extract_call(self)
461+
assert call.callable == InitRef("Rational")
462+
463+
def _to_int(e: TypedExprDecl) -> int:
464+
expr = e.expr
465+
assert isinstance(expr, LitDecl)
466+
assert isinstance(expr.value, int)
467+
return expr.value
468+
469+
num, den = call.args
470+
return Fraction(_to_int(num), _to_int(den))
471+
472+
@method(preserve=True)
473+
def __float__(self) -> float:
474+
return float(self.eval())
475+
476+
@method(preserve=True)
477+
def __int__(self) -> int:
478+
return int(self.eval())
479+
352480
@method(egg_fn="rational")
353481
def __init__(self, num: i64Like, den: i64Like) -> None: ...
354482

@@ -410,6 +538,26 @@ def denom(self) -> i64: ...
410538

411539

412540
class Vec(BuiltinExpr, Generic[T]):
541+
@method(preserve=True)
542+
def eval(self) -> tuple[T, ...]:
543+
call = _extract_call(self)
544+
if call.callable == ClassMethodRef("Vec", "empty"):
545+
return ()
546+
assert call.callable == InitRef("Vec")
547+
return tuple(cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args)
548+
549+
@method(preserve=True)
550+
def __iter__(self) -> Iterator[T]:
551+
return iter(self.eval())
552+
553+
@method(preserve=True)
554+
def __len__(self) -> int:
555+
return len(self.eval())
556+
557+
@method(preserve=True)
558+
def __contains__(self, key: T) -> bool:
559+
return key in self.eval()
560+
413561
@method(egg_fn="vec-of")
414562
def __init__(self, *args: T) -> None: ...
415563

@@ -461,6 +609,13 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
461609

462610

463611
class PyObject(BuiltinExpr):
612+
@method(preserve=True)
613+
def eval(self) -> object:
614+
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, self), 0)
615+
assert isinstance(report, bindings.Best)
616+
expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
617+
return GLOBAL_PY_OBJECT_SORT.load(expr)
618+
464619
def __init__(self, value: object) -> None: ...
465620

466621
@method(egg_fn="py-from-string")
@@ -554,6 +709,8 @@ def __init__(self, f, *partial) -> None: ...
554709
def __call__(self, *args: Unpack[TS]) -> T: ...
555710

556711

712+
# Method Type is for builtins like __getitem__
713+
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
557714
converter(RuntimeFunction, UnstableFn, UnstableFn)
558715
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
559716

@@ -590,3 +747,24 @@ def value_to_annotation(a: object) -> type | None:
590747

591748

592749
converter(FunctionType, UnstableFn, _convert_function)
750+
751+
752+
def _extract_lit(e: BaseExpr) -> bindings._Literal:
753+
"""
754+
Special case extracting literals to make this faster by using termdag directly.
755+
"""
756+
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, e), 0)
757+
assert isinstance(report, bindings.Best)
758+
term = report.term
759+
assert isinstance(term, bindings.TermLit)
760+
return term.value
761+
762+
763+
def _extract_call(e: BaseExpr) -> CallDecl:
764+
"""
765+
Extracts the call form of an expression
766+
"""
767+
extracted = cast(RuntimeExpr, (EGraph.current or EGraph()).extract(e))
768+
expr = extracted.__egg_typed_expr__.expr
769+
assert isinstance(expr, CallDecl)
770+
return expr

python/egglog/conversion.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .egraph import BaseExpr
1818

19-
__all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal", "ConvertError"]
19+
__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
2020
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
2121
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
2222
# Global declerations to store all convertable types so we can query if they have certain methods or not
@@ -153,9 +153,9 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
153153
b_converts_to = {
154154
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
155155
}
156-
if isinstance(a_tp, JustTypeRef):
156+
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
157157
a_converts_to[a_tp] = 0
158-
if isinstance(b_tp, JustTypeRef):
158+
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
159159
b_converts_to[b_tp] = 0
160160
common = set(a_converts_to) & set(b_converts_to)
161161
if not common:

python/egglog/declarations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
196196
return self._classes[class_name].properties[property_name]
197197
case InitRef(class_name):
198198
init_fn = self._classes[class_name].init
199-
assert init_fn
199+
assert init_fn, f"Class {class_name} does not have an init function."
200200
return init_fn
201201
case UnnamedFunctionRef():
202202
return ConstructorDecl(ref.signature)

0 commit comments

Comments
 (0)