Add rewrite for softplus(log(x)) -> log1p(x)
#1452
base: main
Conversation
@@ -453,6 +453,13 @@ def local_exp_log_nan_switch(fgraph, node):
        new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
        return [new_out]

    # Case for softplus(log(x)) -> log1p(x)
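For x > 0 the identity behind the rewrite is softplus(log(x)) = log(1 + e^(log x)) = log1p(x), and the rewritten form is also the numerically safer one near x ≈ 0. A quick plain-NumPy illustration (a sketch, not pytensor code; the variable names are local to this example):

```python
import numpy as np

x = 1e-18
# softplus(log(x)) evaluated literally: exp(log(x)) recovers x, but
# 1.0 + 1e-18 rounds to 1.0 in float64, so the outer log underflows to 0.0
naive = np.log(1.0 + np.exp(np.log(x)))
# the rewritten form keeps the tiny value
rewritten = np.log1p(x)
print(naive, rewritten)  # 0.0 vs ~1e-18
```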
Nitpick: I prefer to refer to it by log1pexp, which we have as an alias to softplus (pytensor/pytensor/tensor/math.py, line 2474 in ff98ab8: log1pexp = softplus).

Also, can we add a similar case for log1mexp?

Hmm, I don't know if that introduces numerical precision issues... However, the somewhat converse
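As a numerical aside on the log1mexp suggestion (plain NumPy, not the library API; the reduction shown is an assumption of this sketch, not something the PR implements): log1pexp(t) = log(1 + exp(t)) is exactly softplus, and log1mexp(t) = log(1 - exp(t)) is defined for t < 0, so applied to t = log(x) with 0 < x < 1 it reduces to log1p(-x):

```python
import numpy as np

x = 0.25
t = np.log(x)                  # t < 0 since 0 < x < 1
# stable log1mexp(t) via expm1: log(1 - exp(t)) = log(-expm1(t))
log1mexp_t = np.log(-np.expm1(t))
reduced = np.log1p(-x)         # candidate rewrite target: log(1 - x)
print(log1mexp_t, reduced)     # both ~ -0.2877
```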
data_invalid = data_valid - 2

x = fmatrix()
f = function([x], softplus(log(x)), mode=self.mode)
If you want, you can check against the expected graph directly, something like:

assert equal_computations(f.maker.fgraph.outputs, [pt.switch(x > 0, pt.log1p(x), np.asarray([[np.nan]], dtype="float32"))])

Or something like that. This is not a request!
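The switch in the suggested expected graph is what preserves NaN semantics: a bare log1p(x) would silently become finite for negative inputs where softplus(log(x)) is NaN. A plain-NumPy illustration of the three behaviors (a sketch, not pytensor code):

```python
import numpy as np

with np.errstate(invalid="ignore"):
    x = -0.5
    # softplus(log(x)): log of a negative number is nan, and softplus propagates it
    original = np.logaddexp(0.0, np.log(x))
    # a bare rewrite to log1p(x) would return finite log(0.5): changed semantics
    bare = np.log1p(x)
    # the guarded form matches the switch in the expected graph
    guarded = np.where(x > 0, np.log1p(x), np.nan)
print(original, bare, guarded)
```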
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@           Coverage Diff           @@
##             main    #1452   +/-   ##
=======================================
  Coverage   82.12%   82.12%
=======================================
  Files         211      211
  Lines       49757    49762    +5
  Branches     8819     8820    +1
=======================================
+ Hits        40862    40867    +5
  Misses       6715     6715
  Partials     2180     2180
Description

This PR adds the simple rewrite softplus(log(x)) -> log1p(x). I could also extend it to cover the case softplus(-log(x)) -> log1p(1/x). However, I have noticed that even the simple exp(-log(x)) -> 1/x is missing, so I wonder if there is an underlying reason, which I am not aware of, to avoid such simplifications. Also, it would be helpful to get feedback from the community before proceeding further with the PR, to ensure that the proposed simplification(s) align with the library's design principles and don't introduce any unintended consequences.
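The two candidate extensions follow from the same algebra: softplus(-log(x)) = log(1 + 1/x) = log1p(1/x) and exp(-log(x)) = 1/x for x > 0. A plain-NumPy sanity check (a sketch, not pytensor code):

```python
import numpy as np

x = np.array([0.3, 2.0, 10.0])
# softplus(-log(x)) computed stably via logaddexp(0, t) = log(1 + exp(t))
lhs1 = np.logaddexp(0.0, -np.log(x))
rhs1 = np.log1p(1.0 / x)
# exp(-log(x)) -> 1/x
lhs2 = np.exp(-np.log(x))
rhs2 = 1.0 / x
print(np.allclose(lhs1, rhs1), np.allclose(lhs2, rhs2))  # True True
```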
Likewise, I would like to know whether the approach and location (within the file and tests) of the change are appropriate. I am not sufficiently familiar with the code to judge whether the approach I have used is better or worse than a PatternNodeRewriter.

Related Issue

softplus(log(x)) -> log1p(x) #1451

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1452.org.readthedocs.build/en/1452/