-
Notifications
You must be signed in to change notification settings - Fork 139
Implement OpPattern
for more flexible tracks
and PatternNodeRewriter
#1594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
c103e75
to
9e89eca
Compare
OpPattern
for more flexible tracks
and PatternNodeRewriter
9e89eca
to
ea37006
Compare
if inp.dtype != ottype: | ||
inp = cast(inp, ottype) | ||
return [inp] | ||
for pair in ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This avoids triggering the rewrite too many times, we spent an unreasonable amount of time in this rewrite the way it was structured, and rarely ended up using it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This info should be put into a comment, since it's unusual (in our code base anyway) that we're using codegen
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defining rewrites in a loop is not what I would call codegen :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment, it does not have the trigger word codegen in it
ea37006
to
b5fb239
Compare
b5fb239
to
1d13782
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
f28fb72
to
dd8cee9
Compare
Codecov Report❌ Patch coverage is ❌ Your patch check has failed because the patch coverage (84.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1594 +/- ##
==========================================
- Coverage 81.70% 81.65% -0.05%
==========================================
Files 230 230
Lines 52950 53002 +52
Branches 9404 9414 +10
==========================================
+ Hits 43262 43280 +18
- Misses 7256 7274 +18
- Partials 2432 2448 +16
🚀 New features to boost your workflow:
|
Only mypy missing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be nice to be able to match on Op properties as well. Do you think that's in-scope for this first pass, or too complex?
Thinking specifically about dimshuffle, being able to do something like OpPattern(Dimshuffle, WithProperty('is_transpose', True))
@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp): | |||
nfunc_spec = ("sign", 1, 1) | |||
|
|||
@staticmethod | |||
def output_types_preference(x): | |||
def _output_types_preference(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _output_types_preference(x): | |
def _get_output_types_preference(x): |
Since it's a method in this case, not a member variable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what init of the base class will look for if I don't pass a function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check for ScalarLoop.__init__
, this plays a role there, and the name matters
pytensor/tensor/rewriting/linalg.py
Outdated
# N.B. this can be further reduced to a yet-unwritten cho_solve Op | ||
# __if__ no other Op makes use of the L matrix during the | ||
# stabilization |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know you only changed the indentation level and thus don't care, but this Op exists now, so this note is out of date.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this rewrite go directly to cho_solve then?
You don't need a special object, it already works if Dimshuffle has that property/attribute. It uses |
I can't reply to one of your comments so here it is:
This API is very strange but I don't want to change it in this PR. They wanted this to be easy to change without creating new Op types (it's used for inplace for instance). Only change I did is the attribute always exists |
dd8cee9
to
f76771a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a flexible OpPattern
class for use in tracks
and PatternNodeRewriter
, enabling more powerful pattern matching in graph rewrites. Key improvements include replacing many hardcoded Op type checks with OpPattern-based tracks and adding support for parameterized Op matching.
- Implements
OpPattern
class for flexible Op parameter matching in unification and tracks - Refactors existing node rewriters to use OpPattern tracks, reducing redundant code and improving performance
- Enhances
PatternNodeRewriter
to support callable output patterns and OpPattern input patterns
Reviewed Changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
pytensor/graph/rewriting/unify.py | Implements OpPattern class and LiteralString for flexible Op matching |
pytensor/graph/rewriting/basic.py | Enhances node rewriter infrastructure to support OpPattern tracks |
pytensor/tensor/rewriting/basic.py | Adds elemwise_of helper and refactors some node rewriters to use OpPattern |
pytensor/tensor/rewriting/blockwise.py | Adds blockwise_of helper for creating Blockwise OpPatterns |
pytensor/tensor/rewriting/math.py | Refactors inverse function pairs using PatternNodeRewriter |
pytensor/tensor/rewriting/linalg.py | Updates linear algebra rewrites to use blockwise_of OpPattern tracks |
pytensor/tensor/rewriting/elemwise.py | Converts elemwise rewrites to use elemwise_of OpPattern tracks |
pytensor/tensor/rewriting/subtensor.py | Updates subtensor rewrites to use blockwise_of patterns |
pytensor/tensor/elemwise.py | Adds is_matrix_transpose property to DimShuffle |
pytensor/scalar/basic.py | Fixes Sign output_types_preference implementation |
tests/graph/rewriting/test_unify.py | Adds tests for OpPattern unification |
tests/graph/rewriting/test_basic.py | Adds tests for OpPattern in PatternNodeRewriter |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
fe64030
to
021017c
Compare
Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
021017c
to
e12fcd4
Compare
I've told mypy to turn a blind eye enough that tests are passing :) |
Spinoff from #1592
It introduces a new
OpPattern
that can be used fortracks
and inPatternNodeRewriter
You can now do:
which wasn't possible before because there are many instances of Solve, unless you defined all ahead of time and enumerated them all in the tracks (including all variations of
on_error
andb_ndim
...).I added OpPattern tracks to almost all Blockwise (and some old Elemwise). Running the tests just in
tests.tensor.rewriting.test_linalg
, I saw a drop in the number of times the local node rewriters were called from 70k to 26k, with only 4k new calls to OpPattern parameter checks.The bulk of the remaining calls from mostly from isolated in2out rewrites, where we don't iterate multiple times and therefore there's not much caching/sharing of tracks we can do. This will be further improved in #1607
The new
OpPattern
can also be used to beef up thePatternNodeRewriter
I further allowed one to use a callable for the output expression (or further reject the rewrite), which is much more flexible.
Here is an example that was not possible before (specifically matching any CAReduce with axis=None) and binding
scalar_op
to be used in the outputElemwise
:📚 Documentation preview 📚: https://pytensor--1594.org.readthedocs.build/en/1594/