-
Notifications
You must be signed in to change notification settings - Fork 136
Do not skip validation between consecutive Elemwise inplace replacements #1494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
fdc5153
to
b05557e
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch check has failed because the patch coverage (92.85%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1494 +/- ##
==========================================
+ Coverage 81.98% 82.02% +0.04%
==========================================
Files 231 231
Lines 52192 52216 +24
Branches 9185 9186 +1
==========================================
+ Hits 42790 42831 +41
+ Misses 7094 7080 -14
+ Partials 2308 2305 -3
🚀 New features to boost your workflow:
|
65f2e35
to
6475c85
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR ensures validation after each in-place rewrite in the Elemwise optimizer, refactors shared logic for Blockwise in-place rewrites, and adds regression tests for both cases.
- Enforce full graph validation between consecutive Elemwise in-place replacements
- Refactor Blockwise in-place optimizer to extend the shared
InplaceGraphOptimizer
- Add tests for regression in Elemwise and partial in-place behavior in Blockwise
Reviewed Changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/tensor/test_blockwise.py | Add test_partial_inplace to cover Blockwise in-place behavior |
tests/tensor/rewriting/test_elemwise.py | Add test_InplaceElemwiseOptimizer_bug regression test for #1420 |
pytensor/tensor/rewriting/blockwise.py | Refactor Blockwise in-place rewrite using InplaceGraphOptimizer |
pytensor/graph/destroyhandler.py | Introduce inplace_candidates helper for unified candidate logic |
Comments suppressed due to low confidence (4)
tests/tensor/rewriting/test_elemwise.py:1535
- [nitpick] Test function names should use snake_case. Consider renaming this to
test_inplace_elemwise_optimizer_bug
for consistency with pytest conventions.
def test_InplaceElemwiseOptimizer_bug():
tests/tensor/rewriting/test_elemwise.py:1544
- The test references
Elemwise
but there is no corresponding import. Please addfrom pytensor.tensor.rewriting.elemwise import Elemwise
at the top of the file.
out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])
pytensor/tensor/rewriting/blockwise.py:4
- The imported
vectorize_node
is no longer used in this file. Consider removing this import to clean up unused code.
from pytensor.graph.replace import vectorize_node
pytensor/tensor/rewriting/blockwise.py:5
- Neither
copy_stack_trace
norout2in
are used after the refactor. Removing these unused imports will improve code clarity.
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
Actually I suspect with my second commit the inplace optimizer can't accidentally introduce invalid graphs anymore. If that's the case we could do a single validate at the end of the rewrite. I didn't actually try to see what's so slow about validate. That's less priority than just the fix here |
6475c85
to
1438953
Compare
Closes #1420
There was a performance-related hack in the ElemwiseInplaceOptimizer where it tried to avoid validating the graph after replacing each node.
This is a bad idea because it may revert valid rewrites that happen to be caught in the same "check" window as an invalid rewrite. More importantly since we started inplacing on multi-output Elemwise, it could trigger an exception when trying to call
has_destroyers
on subsequent nodes, due to a previous invalid replacement.It was hard to track down this issue because the special behavior was only triggered once a graph had more than 500 nodes.
After the refactor I noticed that most of the logic of the rewrite can be shared with Blockwise, so I went ahead and refactored it, which also closes #1457
📚 Documentation preview 📚: https://pytensor--1494.org.readthedocs.build/en/1494/