--- a/pytensor/graph/rewriting/basic.py
+++ b/pytensor/graph/rewriting/basic.py
@@ -13,9 +13,8 @@
 from collections.abc import Callable, Iterable, Sequence
 from functools import _compose_mro, partial  # type: ignore
 from itertools import chain
-from typing import TYPE_CHECKING, Literal
+from typing import Literal
 
-import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph import destroyhandler as dh
 from pytensor.graph.basic import (
@@ -30,15 +29,12 @@
 from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
+from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
 from pytensor.graph.utils import AssocList, InconsistencyError
 from pytensor.misc.ordered_set import OrderedSet
 from pytensor.utils import flatten
 
 
-if TYPE_CHECKING:
-    from pytensor.graph.rewriting.unify import Var
-
-
 _logger = logging.getLogger("pytensor.graph.rewriting.basic")
 
 RemoveKeyType = Literal["remove"]
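The hunk above promotes the `unify` imports to module level and drops both the blanket `import pytensor` and the `TYPE_CHECKING` guard: `Var` is now a real runtime name, so the `dict[str, Var]` annotation used later in the file no longer depends on a checker-only import. This is safe provided `pytensor.graph.rewriting.unify` does not import this module back (an import cycle being the usual reason such imports are deferred in the first place). As a standalone illustration of the pattern being removed, a name imported under `TYPE_CHECKING` exists only for static analysis:

# Minimal self-contained sketch (not PyTensor code) of the TYPE_CHECKING
# pattern the diff removes: the guarded name is visible to type checkers
# but absent at runtime, so it may only appear in string annotations.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # resolved by the type checker only


def total(values: "list[Decimal]") -> "Decimal":  # string annotations: fine
    # Using Decimal as a value here would raise NameError at runtime,
    # which is why such imports get promoted to module level once the
    # name is needed outside annotations.
    raise NotImplementedError
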
@@ -1414,8 +1410,6 @@ def __init__(
         frequent `Op`, which will prevent the rewrite from being tried as often.
 
         """
-        from pytensor.graph.rewriting.unify import convert_strs_to_vars
-
         var_map: dict[str, Var] = {}
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
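With `convert_strs_to_vars` now a module-level name, `__init__` no longer re-imports it on every `PatternNodeRewriter` construction. For context, here is a hedged usage sketch of the constructor this hunk touches, assuming the nested-tuple pattern syntax in which bare strings act as unification wildcards (the strings are what `convert_strs_to_vars` turns into `Var`s):

# Hedged sketch: a rewriter that replaces log(exp(x)) with x.
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import PatternNodeRewriter

log_exp_to_x = PatternNodeRewriter(
    (pt.log, (pt.exp, "x")),  # in_pattern: "x" becomes a unify Var
    "x",                      # out_pattern: reuse whatever matched "x"
)
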
@@ -1457,9 +1451,6 @@ def transform(self, fgraph, node, get_nodes=True):
                 if ret is not False and ret is not None:
                     return dict(zip(real_node.outputs, ret, strict=True))
 
-        if node.op != self.op:
-            return False
-
         if len(node.outputs) != 1:
             # PatternNodeRewriter doesn't support replacing multi-output nodes
             return False
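The deleted `node.op != self.op` guard is presumably redundant: `PatternNodeRewriter` advertises its root `Op` to the rewriting engine via `tracks()`, so `transform` should only be dispatched to nodes whose op already matches. A hedged sketch of that contract with a hand-written `NodeRewriter` (the op and rewrite are illustrative, not from this diff):

# Sketch of the tracks()/transform contract as I understand it: engines
# that honor tracking only call transform on nodes whose op is listed,
# which makes an explicit op check inside transform redundant.
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import NodeRewriter


class DropDoubleNegation(NodeRewriter):
    def tracks(self):
        return [pt.neg]  # only neg nodes should reach transform

    def transform(self, fgraph, node):
        # node.op is already known to match pt.neg here.
        (inner,) = node.inputs
        if inner.owner is not None and inner.owner.op == node.op:
            return [inner.owner.inputs[0]]  # neg(neg(x)) -> x
        return False
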
@@ -1488,11 +1479,13 @@ def transform(self, fgraph, node, get_nodes=True):
 
         [old_out] = node.outputs
         if not old_out.type.is_super(ret.type):
+            from pytensor.tensor.type import TensorType
+
             # Type doesn't match
             if not (
                 self.allow_cast
-                and isinstance(old_out.type, pytensor.tensor.TensorType)
-                and isinstance(ret.type, pytensor.tensor.TensorType)
+                and isinstance(old_out.type, TensorType)
+                and isinstance(ret.type, TensorType)
             ):
                 return False
 
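The function-local `from pytensor.tensor.type import TensorType` trades the old attribute lookup through the top-level `pytensor` package for a deferred import, presumably to keep this graph-level module from importing the tensor package at load time. The surrounding logic only tolerates a type mismatch when `allow_cast` is set and both sides are `TensorType`s. A hedged illustration of the `is_super` test that gates this branch:

# Illustration of the type-compatibility check above; assumes the usual
# TensorType(dtype, shape) constructor.
from pytensor.tensor.type import TensorType

t64 = TensorType("float64", shape=(None,))
t32 = TensorType("float32", shape=(None,))

# A type "is super" of another when it can stand in for it. A dtype
# mismatch fails the test, and that is the case allow_cast rescues for
# TensorTypes by casting instead of rejecting the rewrite.
assert t64.is_super(t64)
assert not t64.is_super(t32)
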
@@ -2744,10 +2737,10 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
     otherwise.
 
     """
-    if isinstance(f_or_fgraph, pytensor.compile.function.types.Function):
-        fgraph = f_or_fgraph.maker.fgraph
-    elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph):
+    if isinstance(f_or_fgraph, FunctionGraph):
         fgraph = f_or_fgraph
+    elif hasattr(f_or_fgraph, "fgraph"):
+        fgraph = f_or_fgraph.fgraph
     else:
         raise ValueError("The type of f_or_fgraph is not supported")
 
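Dispatching on `isinstance(..., FunctionGraph)` first and then duck-typing on an `fgraph` attribute removes the last uses of the top-level `pytensor` package in this module, which is what lets the earlier hunk delete `import pytensor`. It also implies that compiled `Function` objects expose their rewritten graph directly as `.fgraph` (the old branch went through `.maker.fgraph`). A hedged usage sketch under that assumption:

# Usage sketch of check_stack_trace; assumes, as the new branch implies,
# that a compiled Function exposes its rewritten graph as `.fgraph`.
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import check_stack_trace

x = pt.vector("x")
f = pytensor.function([x], pt.exp(x))

# True when the checked nodes still carry a stack trace after rewriting;
# "last" (the default) checks the nodes producing the graph's outputs.
ok = check_stack_trace(f, ops_to_check="last")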