5
5
6
6
from __future__ import annotations
7
7
8
+ from fractions import Fraction
8
9
from functools import partial , reduce
9
- from types import FunctionType
10
+ from types import FunctionType , MethodType
10
11
from typing import TYPE_CHECKING , Generic , Protocol , TypeAlias , TypeVar , Union , cast , overload
11
12
12
13
from typing_extensions import TypeVarTuple , Unpack
13
14
15
+ from . import bindings
14
16
from .conversion import convert , converter , get_type_args
15
- from .egraph import BaseExpr , BuiltinExpr , Unit , function , get_current_ruleset , method
17
+ from .declarations import *
18
+ from .egraph import BaseExpr , BuiltinExpr , EGraph , expr_fact , function , get_current_ruleset , method
19
+ from .egraph_state import GLOBAL_PY_OBJECT_SORT
16
20
from .functionalize import functionalize
17
21
from .runtime import RuntimeClass , RuntimeExpr , RuntimeFunction
18
22
from .thunk import Thunk
19
23
20
24
if TYPE_CHECKING :
21
- from collections .abc import Callable
25
+ from collections .abc import Callable , Iterator
22
26
23
27
24
28
__all__ = [
32
36
"SetLike" ,
33
37
"String" ,
34
38
"StringLike" ,
39
+ "Unit" ,
35
40
"UnstableFn" ,
36
41
"Vec" ,
37
42
"VecLike" ,
46
51
]
47
52
48
53
54
+ class Unit (BuiltinExpr , egg_sort = "Unit" ):
55
+ """
56
+ The unit type. This is used to reprsent if a value exists in the e-graph or not.
57
+ """
58
+
59
+ def __init__ (self ) -> None : ...
60
+
61
+ @method (preserve = True )
62
+ def __bool__ (self ) -> bool :
63
+ return bool (expr_fact (self ))
64
+
65
+
49
66
class String (BuiltinExpr ):
67
+ @method (preserve = True )
68
+ def eval (self ) -> str :
69
+ value = _extract_lit (self )
70
+ assert isinstance (value , bindings .String )
71
+ return value .value
72
+
50
73
def __init__ (self , value : str ) -> None : ...
51
74
52
75
@method (egg_fn = "replace" )
@@ -62,10 +85,20 @@ def join(*strings: StringLike) -> String: ...
62
85
63
86
converter (str , String , String )
64
87
65
- BoolLike = Union ["Bool" , bool ]
88
+ BoolLike : TypeAlias = Union ["Bool" , bool ]
66
89
67
90
68
91
class Bool (BuiltinExpr , egg_sort = "bool" ):
92
+ @method (preserve = True )
93
+ def eval (self ) -> bool :
94
+ value = _extract_lit (self )
95
+ assert isinstance (value , bindings .Bool )
96
+ return value .value
97
+
98
+ @method (preserve = True )
99
+ def __bool__ (self ) -> bool :
100
+ return self .eval ()
101
+
69
102
def __init__ (self , value : bool ) -> None : ...
70
103
71
104
@method (egg_fn = "not" )
@@ -91,6 +124,20 @@ def implies(self, other: BoolLike) -> Bool: ...
91
124
92
125
93
126
class i64 (BuiltinExpr ): # noqa: N801
127
+ @method (preserve = True )
128
+ def eval (self ) -> int :
129
+ value = _extract_lit (self )
130
+ assert isinstance (value , bindings .Int )
131
+ return value .value
132
+
133
+ @method (preserve = True )
134
+ def __index__ (self ) -> int :
135
+ return self .eval ()
136
+
137
+ @method (preserve = True )
138
+ def __int__ (self ) -> int :
139
+ return self .eval ()
140
+
94
141
def __init__ (self , value : int ) -> None : ...
95
142
96
143
@method (egg_fn = "+" )
@@ -193,6 +240,20 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
193
240
194
241
195
242
class f64 (BuiltinExpr ): # noqa: N801
243
+ @method (preserve = True )
244
+ def eval (self ) -> float :
245
+ value = _extract_lit (self )
246
+ assert isinstance (value , bindings .Float )
247
+ return value .value
248
+
249
+ @method (preserve = True )
250
+ def __float__ (self ) -> float :
251
+ return self .eval ()
252
+
253
+ @method (preserve = True )
254
+ def __int__ (self ) -> int :
255
+ return int (self .eval ())
256
+
196
257
def __init__ (self , value : float ) -> None : ...
197
258
198
259
@method (egg_fn = "neg" )
@@ -265,6 +326,33 @@ def to_string(self) -> String: ...
265
326
266
327
267
328
class Map (BuiltinExpr , Generic [T , V ]):
329
+ @method (preserve = True )
330
+ def eval (self ) -> dict [T , V ]:
331
+ call = _extract_call (self )
332
+ expr = cast (RuntimeExpr , self )
333
+ d = {}
334
+ while call .callable != ClassMethodRef ("Map" , "empty" ):
335
+ assert call .callable == MethodRef ("Map" , "insert" )
336
+ call_typed , k_typed , v_typed = call .args
337
+ assert isinstance (call_typed .expr , CallDecl )
338
+ k = cast (T , expr .__with_expr__ (k_typed ))
339
+ v = cast (V , expr .__with_expr__ (v_typed ))
340
+ d [k ] = v
341
+ call = call_typed .expr
342
+ return d
343
+
344
+ @method (preserve = True )
345
+ def __iter__ (self ) -> Iterator [T ]:
346
+ return iter (self .eval ())
347
+
348
+ @method (preserve = True )
349
+ def __len__ (self ) -> int :
350
+ return len (self .eval ())
351
+
352
+ @method (preserve = True )
353
+ def __contains__ (self , key : T ) -> bool :
354
+ return key in self .eval ()
355
+
268
356
@method (egg_fn = "map-empty" )
269
357
@classmethod
270
358
def empty (cls ) -> Map [T , V ]: ...
@@ -305,6 +393,24 @@ def rebuild(self) -> Map[T, V]: ...
305
393
306
394
307
395
class Set (BuiltinExpr , Generic [T ]):
396
+ @method (preserve = True )
397
+ def eval (self ) -> set [T ]:
398
+ call = _extract_call (self )
399
+ assert call .callable == InitRef ("Set" )
400
+ return {cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args }
401
+
402
+ @method (preserve = True )
403
+ def __iter__ (self ) -> Iterator [T ]:
404
+ return iter (self .eval ())
405
+
406
+ @method (preserve = True )
407
+ def __len__ (self ) -> int :
408
+ return len (self .eval ())
409
+
410
+ @method (preserve = True )
411
+ def __contains__ (self , key : T ) -> bool :
412
+ return key in self .eval ()
413
+
308
414
@method (egg_fn = "set-of" )
309
415
def __init__ (self , * args : T ) -> None : ...
310
416
@@ -349,6 +455,28 @@ def rebuild(self) -> Set[T]: ...
349
455
350
456
351
457
class Rational (BuiltinExpr ):
458
+ @method (preserve = True )
459
+ def eval (self ) -> Fraction :
460
+ call = _extract_call (self )
461
+ assert call .callable == InitRef ("Rational" )
462
+
463
+ def _to_int (e : TypedExprDecl ) -> int :
464
+ expr = e .expr
465
+ assert isinstance (expr , LitDecl )
466
+ assert isinstance (expr .value , int )
467
+ return expr .value
468
+
469
+ num , den = call .args
470
+ return Fraction (_to_int (num ), _to_int (den ))
471
+
472
+ @method (preserve = True )
473
+ def __float__ (self ) -> float :
474
+ return float (self .eval ())
475
+
476
+ @method (preserve = True )
477
+ def __int__ (self ) -> int :
478
+ return int (self .eval ())
479
+
352
480
@method (egg_fn = "rational" )
353
481
def __init__ (self , num : i64Like , den : i64Like ) -> None : ...
354
482
@@ -410,6 +538,26 @@ def denom(self) -> i64: ...
410
538
411
539
412
540
class Vec (BuiltinExpr , Generic [T ]):
541
+ @method (preserve = True )
542
+ def eval (self ) -> tuple [T , ...]:
543
+ call = _extract_call (self )
544
+ if call .callable == ClassMethodRef ("Vec" , "empty" ):
545
+ return ()
546
+ assert call .callable == InitRef ("Vec" )
547
+ return tuple (cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args )
548
+
549
+ @method (preserve = True )
550
+ def __iter__ (self ) -> Iterator [T ]:
551
+ return iter (self .eval ())
552
+
553
+ @method (preserve = True )
554
+ def __len__ (self ) -> int :
555
+ return len (self .eval ())
556
+
557
+ @method (preserve = True )
558
+ def __contains__ (self , key : T ) -> bool :
559
+ return key in self .eval ()
560
+
413
561
@method (egg_fn = "vec-of" )
414
562
def __init__ (self , * args : T ) -> None : ...
415
563
@@ -461,6 +609,13 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
461
609
462
610
463
611
class PyObject (BuiltinExpr ):
612
+ @method (preserve = True )
613
+ def eval (self ) -> object :
614
+ report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , self ), 0 )
615
+ assert isinstance (report , bindings .Best )
616
+ expr = report .termdag .term_to_expr (report .term , bindings .PanicSpan ())
617
+ return GLOBAL_PY_OBJECT_SORT .load (expr )
618
+
464
619
def __init__ (self , value : object ) -> None : ...
465
620
466
621
@method (egg_fn = "py-from-string" )
@@ -554,6 +709,8 @@ def __init__(self, f, *partial) -> None: ...
554
709
def __call__ (self , * args : Unpack [TS ]) -> T : ...
555
710
556
711
712
+ # Method Type is for builtins like __getitem__
713
+ converter (MethodType , UnstableFn , lambda m : UnstableFn (m .__func__ , m .__self__ ))
557
714
converter (RuntimeFunction , UnstableFn , UnstableFn )
558
715
converter (partial , UnstableFn , lambda p : UnstableFn (p .func , * p .args ))
559
716
@@ -590,3 +747,24 @@ def value_to_annotation(a: object) -> type | None:
590
747
591
748
592
749
converter (FunctionType , UnstableFn , _convert_function )
750
+
751
+
752
+ def _extract_lit (e : BaseExpr ) -> bindings ._Literal :
753
+ """
754
+ Special case extracting literals to make this faster by using termdag directly.
755
+ """
756
+ report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , e ), 0 )
757
+ assert isinstance (report , bindings .Best )
758
+ term = report .term
759
+ assert isinstance (term , bindings .TermLit )
760
+ return term .value
761
+
762
+
763
+ def _extract_call (e : BaseExpr ) -> CallDecl :
764
+ """
765
+ Extracts the call form of an expression
766
+ """
767
+ extracted = cast (RuntimeExpr , (EGraph .current or EGraph ()).extract (e ))
768
+ expr = extracted .__egg_typed_expr__ .expr
769
+ assert isinstance (expr , CallDecl )
770
+ return expr
0 commit comments