Optimize matmuls involving block diagonal matrices #1493
base: main
Conversation
Pull Request Overview
This PR introduces an optimization that rewrites matrix multiplications involving block diagonal matrices into separate smaller multiplications and concatenations, yielding significant performance gains. It also adds tests to verify the rewrite and benchmarks to measure its impact.
- Implement the `local_block_diag_dot_to_dot_block_diag` rewrite in `math.py`
- Import and wire up the necessary primitives (`split`, `join`, `BlockDiagonal`)
- Add unit tests and benchmarks in `test_math.py` to validate correctness and performance
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
| --- | --- |
| pytensor/tensor/rewriting/math.py | Added the `local_block_diag_dot_to_dot_block_diag` rewrite and required imports (`split`, `join`, `BlockDiagonal`) |
| tests/tensor/rewriting/test_math.py | Added tests (`test_local_block_diag_dot_to_dot_block_diag`) and benchmarks (`test_block_diag_dot_to_dot_concat_benchmark`) |
Comments suppressed due to low confidence (1)
pytensor/tensor/rewriting/math.py:191
- The name `Blockwise` is referenced but not imported, which will raise a `NameError` if the first condition is false. Add `from pytensor.tensor.slinalg import Blockwise` (or the correct module) at the top of the file.

      or isinstance(x.owner.op, Blockwise)
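For reference, in recent PyTensor versions `Blockwise` appears to live in `pytensor.tensor.blockwise` rather than `slinalg`, so the missing import would presumably be something like:

```python
# Assumed import path; double-check against the PyTensor version this PR targets.
from pytensor.tensor.blockwise import Blockwise
```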
Codecov Report

Attention: Patch coverage is 85.71%.

❌ Your patch check has failed because the patch coverage (85.71%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@ Coverage Diff @@
##             main    #1493   +/-   ##
=======================================
  Coverage        ?   81.98%
=======================================
  Files           ?      231
  Lines           ?    52231
  Branches        ?     9196
=======================================
  Hits            ?    42822
  Misses          ?     7098
  Partials        ?     2311
Looks great! Some minor optimization questions
# non-block diagonal, and return a new block diagonal
if check_for_block_diag(x) and not check_for_block_diag(y):
    components = x.owner.inputs
    y_splits = split(
Isn't this and the join along the 0th axis assuming a `Blockwise` `BlockDiagonal` without batch dims?
Also not sure why you look for `Dot` but not `Blockwise` of `_matrix_matrix_matmul`. It doesn't always get rewritten as a `Dot`.
I guess because you don't look for batch dot you can only have a `BlockDiagonal` without batch dims. That's fine, but maybe a bit implicit. You can also wait for the useless `Blockwise` `BlockDiagonal` to be rewritten as `BlockDiagonal` and only track that.
More importantly, because you track a regular dot and not the matmul, you may have a vector * matrix or matrix * vector product. Does the rewrite handle these correctly?
If this is what's implied, it's not on purpose. I'll modify it to account for blockwise dot.

Is there no canonical dot form we rewrite to in an intermediate step to make reasoning about graphs easier? It seems nuts to have to look for a bunch of different ops: `_matrix_matrix_matmul` or `_matrix_vec_matmul` or `dot22` or whatever.
I'm simplifying every blockwise dot as a blockwise 2x2 dot (i.e. matmul) in #1471.

The `dot22` and `dot22scalar` Ops are from the BLAS pipeline and I've been hesitant to touch them. As first steps I would like to move them after specialize and to get rid of `dot22scalar` (it should just be gemm). That BLAS stuff should also work with blockwise, but it currently doesn't.

Anyway, if you target blockwise and the core 2x2 dot in this PR, that should be the most robust going forward, even if it misses some cases now. I suggest you explicitly exclude the vector-matrix dots for now.
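A minimal sketch of what tracking both the core `Dot` and a `Blockwise`-wrapped matmul could look like (illustrative only, not the PR's actual code; the rewriter name and the early-exit checks here are hypothetical):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot


@node_rewriter([Dot, Blockwise])
def local_block_diag_dot_sketch(fgraph, node):
    # Accept a core Dot node, or a Blockwise whose core op is Dot (i.e. a matmul).
    op = node.op
    if isinstance(op, Blockwise) and not isinstance(op.core_op, Dot):
        return None
    x, y = node.inputs
    # Explicitly exclude vector-matrix / matrix-vector products, as suggested above.
    if x.type.ndim < 2 or y.type.ndim < 2:
        return None
    ...  # rest of the rewrite: split the other operand, do per-block dots, join
```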
skipping matrix-vector is a bit of a bummer
Description
This PR adds a rewrite to optimize matrix multiplication involving block diagonal matrices. When we have a matrix `X = BlockDiag(A, B)` and compute `Z = X @ Y`, there is no interaction between terms in the `A` part and the `B` part of `X`. So the dot can instead be computed as `row_stack(A @ Y[:A.shape[1]], B @ Y[A.shape[1]:])` (or, in the general case, `Y` can be split into `n` pieces with appropriate shapes, and we do `row_stack([diag_component @ y_split for diag_component, y_split in zip(BlockDiag.inputs, split(Y, *args))])`). In the case where the block diagonal matrix is right-multiplying, you instead `col_stack` and slice on `axis=1`.

Anyway, it's a lot faster to do this, because matmuls scale really badly in the dimension of the input, so doing two smaller operations is preferred. Here are the benchmarks: small has `n=10`, medium has `n=100`, large has `n=1000`. In all cases the rewrite shows at least a 2x speedup.
Related Issue

block_diag(a, b) @ c #1044

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1493.org.readthedocs.build/en/1493/