Skip to content

Commit 99c0adf

Browse files
Merge pull request #286 from egraphs-good/fix-import-bug
Fix bug resolving types
2 parents ac78593 + 19e9898 commit 99c0adf

11 files changed

+51
-23
lines changed

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/changelog.md

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Fix bug on resolving types if not all imported to your module [#286](https://github.com/egraphs-good/egglog-python/pull/286)
8+
- Also stops special casing including `Callable` as a global. So if you previously included this in a `TYPE_CHECKING` block so it wasn
9+
available at runtime you will have to move this to a runtime import if used in a type alias.
10+
711
## 10.0.0 (2025-03-28)
812

913
- Change builtins to not evaluate values in egraph and changes facts to compare structural equality instead of using an egraph when converting to a boolean, removing magic context (`EGraph.current` and `Schedule.current`) that was added in release 9.0.0.

pyproject.toml

+13
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ ignore = [
198198
"SLF001",
199199
# allow blind exception to add context
200200
"BLE001",
201+
# Don't move type checking around so that can be accessed at runtime
202+
"TCH001",
203+
"TCH002",
204+
"TCH003",
201205
]
202206
select = ["ALL"]
203207

@@ -217,6 +221,15 @@ preview = true
217221
# Don't require annotations for tests
218222
"python/tests/**" = ["ANN001", "ANN201", "INP001"]
219223

224+
# Disable these tests instead for now since ruff doesn't support including all method annotations of decorated class
225+
# [tool.ruff.lint.flake8-type-checking]
226+
# runtime-evaluated-decorators = [
227+
# "egglog.function",
228+
# "egglog.method",
229+
# "egglog.ruleset",
230+
# ]
231+
# runtime-evaluated-base-classes = ["egglog.Expr"]
232+
220233
[tool.mypy]
221234
ignore_missing_imports = true
222235
warn_redundant_casts = true

python/egglog/builtins.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from collections.abc import Callable
89
from fractions import Fraction
910
from functools import partial, reduce
1011
from types import FunctionType, MethodType
@@ -20,7 +21,7 @@
2021
from .thunk import Thunk
2122

2223
if TYPE_CHECKING:
23-
from collections.abc import Callable, Iterator
24+
from collections.abc import Iterator
2425

2526

2627
__all__ = [

python/egglog/egraph.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1817,12 +1817,13 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
18171817
"""
18181818
Returns a thunk which will call the function with variables of the type and name of the arguments.
18191819
"""
1820-
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
1821-
# but not in the globals
1822-
globals = gen.__globals__.copy()
1823-
if "Callable" not in globals:
1824-
globals["Callable"] = Callable
1825-
hints = get_type_hints(gen, globals, frame.f_locals)
1820+
# Need to manually pass in the frame locals from the generator, because otherwise classes defined within function
1821+
# will not be available in the annotations
1822+
# combine locals and globals so that they are the same dict. Otherwise get_type_hints will go through the wrong
1823+
# path and give an error for the test
1824+
# python/tests/test_no_import_star.py::test_no_import_star_rulesset
1825+
combined = {**gen.__globals__, **frame.f_locals}
1826+
hints = get_type_hints(gen, combined, combined)
18261827
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
18271828
return list(gen(*args)) # type: ignore[misc]
18281829

python/egglog/examples/higher_order_functions.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66

77
from __future__ import annotations
88

9-
from typing import TYPE_CHECKING
9+
from collections.abc import Callable
1010

1111
from egglog import *
1212

13-
if TYPE_CHECKING:
14-
from collections.abc import Callable
15-
1613

1714
class Math(Expr):
1815
def __init__(self, i: i64Like) -> None: ...

python/egglog/examples/lambda_.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77

88
from __future__ import annotations
99

10-
from typing import TYPE_CHECKING, ClassVar
10+
from collections.abc import Callable
11+
from typing import ClassVar
1112

1213
from egglog import *
13-
14-
if TYPE_CHECKING:
15-
from collections.abc import Callable
14+
from egglog import Expr
1615

1716

1817
class Val(Expr):

python/egglog/exp/array_api.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import numbers
6161
import os
6262
import sys
63+
from collections.abc import Callable
6364
from copy import copy
6465
from types import EllipsisType
6566
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
@@ -72,9 +73,10 @@
7273
from .program_gen import *
7374

7475
if TYPE_CHECKING:
75-
from collections.abc import Callable, Iterator
76+
from collections.abc import Iterator
7677
from types import ModuleType
7778

79+
7880
# Pretend that exprs are numbers b/c sklearn does isinstance checks
7981
numbers.Integral.register(RuntimeExpr)
8082

python/tests/test_no_import_star.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import egglog as el
24

35

@@ -10,3 +12,15 @@ class Num(el.Expr):
1012
def __init__(self, value: el.i64Like) -> None: ...
1113

1214
Num(1) # gets an error "NameError: name 'i64' is not defined"
15+
16+
17+
def test_no_import_star_rulesset():
18+
"""
19+
https://github.com/egraphs-good/egglog-python/issues/283
20+
"""
21+
22+
@el.ruleset
23+
def _rules(_: el.i64Like):
24+
return []
25+
26+
el.EGraph().run(_rules)

python/tests/test_pretty.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mypy: disable-error-code="empty-body"
22
from __future__ import annotations
33

4+
from collections.abc import Callable
45
from copy import copy
56
from functools import partial
67
from typing import TYPE_CHECKING, ClassVar
@@ -10,8 +11,6 @@
1011
from egglog import *
1112

1213
if TYPE_CHECKING:
13-
from collections.abc import Callable
14-
1514
from egglog.runtime import RuntimeExpr
1615

1716

python/tests/test_unstable_fn.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55

66
from __future__ import annotations
77

8+
from collections.abc import Callable
89
from functools import partial
9-
from typing import TYPE_CHECKING, ClassVar, TypeAlias
10+
from typing import ClassVar, TypeAlias
1011

1112
from egglog import *
1213

13-
if TYPE_CHECKING:
14-
from collections.abc import Callable
15-
1614

1715
class Math(Expr):
1816
def __init__(self, n: i64Like) -> None: ...

0 commit comments

Comments
 (0)