Skip to content

Commit 229c14d

Browse files
committed
Remove unused skip_identities_fn
1 parent dfef95b commit 229c14d

File tree

2 files changed

+3
-32
lines changed

2 files changed

+3
-32
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,7 +1545,6 @@ def __init__(
15451545
in_pattern,
15461546
out_pattern,
15471547
allow_multiple_clients: bool = False,
1548-
skip_identities_fn=None,
15491548
name: str | None = None,
15501549
tracks=(),
15511550
get_nodes=None,
@@ -1602,7 +1601,6 @@ def __init__(
16021601
)
16031602
self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n"
16041603
self.allow_multiple_clients = allow_multiple_clients
1605-
self.skip_identities_fn = skip_identities_fn
16061604
if name:
16071605
self.__name__ = name
16081606
self._tracks = tracks

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3162,44 +3162,19 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
31623162
return np.allclose(x, ref, rtol=rtol, atol=atol)
31633163

31643164

3165-
def _skip_mul_1(r):
3166-
if r.owner and r.owner.op == mul:
3167-
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
3168-
if len(not_is_1) == 1:
3169-
return not_is_1[0]
3170-
3171-
3172-
def _is_1(expr):
3173-
"""
3174-
3175-
Returns
3176-
-------
3177-
bool
3178-
True iff expr is a constant close to 1.
3179-
3180-
"""
3181-
try:
3182-
v = get_underlying_scalar_constant_value(expr)
3183-
return isclose(v, 1)
3184-
except NotScalarConstantError:
3185-
return False
3186-
3187-
31883165
logsigm_to_softplus = PatternNodeRewriter(
31893166
(log, (sigmoid, "x")),
31903167
(neg, (softplus, (neg, "x"))),
31913168
allow_multiple_clients=True,
31923169
values_eq_approx=values_eq_approx_remove_inf,
3193-
skip_identities_fn=_skip_mul_1,
31943170
tracks=[sigmoid],
31953171
get_nodes=get_clients_at_depth1,
31963172
)
31973173
log1msigm_to_softplus = PatternNodeRewriter(
3198-
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
3174+
(log, (sub, 1, (sigmoid, "x"))),
31993175
(neg, (softplus, "x")),
32003176
allow_multiple_clients=True,
32013177
values_eq_approx=values_eq_approx_remove_inf,
3202-
skip_identities_fn=_skip_mul_1,
32033178
tracks=[sigmoid],
32043179
get_nodes=get_clients_at_depth2,
32053180
)
@@ -3396,10 +3371,8 @@ def local_exp_over_1_plus_exp(fgraph, node):
33963371

33973372
if len(denom_rest) == 0:
33983373
return [new_num]
3399-
elif len(denom_rest) == 1:
3400-
out = new_num / denom_rest[0]
34013374
else:
3402-
out = new_num / mul(*denom_rest)
3375+
out = new_num / variadic_mul(*denom_rest)
34033376

34043377
copy_stack_trace(node.outputs[0], out)
34053378
return [out]
@@ -3769,7 +3742,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
37693742

37703743
# 1 - sigmoid(x) -> sigmoid(-x)
37713744
local_1msigmoid = PatternNodeRewriter(
3772-
(sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
3745+
(sub, 1, (sigmoid, "x")),
37733746
(sigmoid, (neg, "x")),
37743747
tracks=[sigmoid],
37753748
get_nodes=get_clients_at_depth1,

0 commit comments

Comments
 (0)