Skip to content

Commit d3bbc20

Browse files
lcitiLuca Citi
andauthored
Cover more cases of log1mexp stabilization (#1483)
* Created some tests that fail due to #1476 * Fixes 1476 and other ways to create a log1mexp * Reimplemented logmexpm1_to_log1mexp by tracking expm1 and then looking through the clients * Absorbed the rewrite log1pexp_to_softplus into the new rewrite for log1mexp * Fixed bug where I forgot to check whether result of is_neg was None or not before proceeding --------- Co-authored-by: Luca Citi <[email protected]>
1 parent 6aeed97 commit d3bbc20

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
log,
6565
log1mexp,
6666
log1p,
67+
log1pexp,
6768
makeKeepDims,
6869
maximum,
6970
mul,
@@ -2999,12 +3000,6 @@ def _is_1(expr):
29993000
tracks=[sigmoid],
30003001
get_nodes=get_clients_at_depth2,
30013002
)
3002-
log1pexp_to_softplus = PatternNodeRewriter(
3003-
(log1p, (exp, "x")),
3004-
(softplus, "x"),
3005-
values_eq_approx=values_eq_approx_remove_inf,
3006-
allow_multiple_clients=True,
3007-
)
30083003
log1p_neg_sigmoid = PatternNodeRewriter(
30093004
(log1p, (neg, (sigmoid, "x"))),
30103005
(neg, (softplus, "x")),
@@ -3016,7 +3011,6 @@ def _is_1(expr):
30163011

30173012
register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
30183013
register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
3019-
register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
30203014
register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
30213015
register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
30223016

@@ -3582,12 +3576,40 @@ def local_reciprocal_1_plus_exp(fgraph, node):
35823576
register_specialize(local_1msigmoid)
35833577

35843578

3585-
log1pmexp_to_log1mexp = PatternNodeRewriter(
3586-
(log1p, (neg, (exp, "x"))),
3587-
(log1mexp, "x"),
3588-
allow_multiple_clients=True,
3589-
)
3590-
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
3579+
@register_stabilize
3580+
@node_rewriter([log1p])
3581+
def local_log1p_plusminus_exp(fgraph, node):
3582+
"""Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp
3583+
``log1p(exp(x)) -> log1pexp(x)``
3584+
``log1p(-exp(x)) -> log1mexp(x)``
3585+
where "-" can be "neg" or any other expression detected by "is_neg"
3586+
"""
3587+
(log1p_arg,) = node.inputs
3588+
exp_info = is_exp(log1p_arg)
3589+
if exp_info is not None:
3590+
exp_neg, exp_arg = exp_info
3591+
if exp_neg:
3592+
return [log1mexp(exp_arg)]
3593+
else:
3594+
return [log1pexp(exp_arg)] # aka softplus
3595+
3596+
3597+
@register_stabilize
3598+
@node_rewriter([expm1])
3599+
def logmexpm1_to_log1mexp(fgraph, node):
3600+
"""``log(-expm1(x)) -> log1mexp(x)``
3601+
where "-" can be "neg" or any other expression detected by "is_neg"
3602+
"""
3603+
rewrites = {}
3604+
for node in get_clients_at_depth(fgraph, node, depth=2):
3605+
if node.op == log:
3606+
(log_arg,) = node.inputs
3607+
neg_arg = is_neg(log_arg)
3608+
if neg_arg is not None and neg_arg.owner and neg_arg.owner.op == expm1:
3609+
(expm1_arg,) = neg_arg.owner.inputs
3610+
rewrites[node.outputs[0]] = log1mexp(expm1_arg)
3611+
return rewrites
3612+
35913613

35923614
# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
35933615
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(

tests/tensor/rewriting/test_math.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4438,11 +4438,22 @@ def test_local_add_neg_to_sub(first_negative):
44384438
assert np.allclose(f(x_test, y_test), exp)
44394439

44404440

4441-
def test_log1mexp_stabilization():
4441+
@pytest.mark.parametrize(
4442+
"op_name",
4443+
["log_1_minus_exp", "log1p_minus_exp", "log_minus_expm1", "log_minus_exp_minus_1"],
4444+
)
4445+
def test_log1mexp_stabilization(op_name):
44424446
mode = Mode("py").including("stabilize")
44434447

44444448
x = vector()
4445-
f = function([x], log(1 - exp(x)), mode=mode)
4449+
if op_name == "log_1_minus_exp":
4450+
f = function([x], log(1 - exp(x)), mode=mode)
4451+
elif op_name == "log1p_minus_exp":
4452+
f = function([x], log1p(-exp(x)), mode=mode)
4453+
elif op_name == "log_minus_expm1":
4454+
f = function([x], log(-expm1(x)), mode=mode)
4455+
elif op_name == "log_minus_exp_minus_1":
4456+
f = function([x], log(-(exp(x) - 1)), mode=mode)
44464457

44474458
nodes = [node.op for node in f.maker.fgraph.toposort()]
44484459
assert nodes == [pt.log1mexp]

0 commit comments

Comments
 (0)