Refactor and update QR Op #1518
Conversation
match self.mode:
    case "full":
        outputs = [
            tensor("Q", shape=(M, M), dtype=out_dtype),
It's better not to give names to intermediate variables; this would shadow a user's Q if they wanted to do eval({"Q": something}) or find a variable by name.
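For illustration, a minimal sketch of the clash (hypothetical usage, not code from this PR); here the user happens to call their own input "Q":

import numpy as np
from pytensor.tensor import matrix
from pytensor.tensor.nlinalg import qr

A = matrix("Q", dtype="float64")  # the user picked the name "Q" for their input
Q, R = qr(A)                      # if the Op also named its output "Q", that name would exist twice in the graph
# eval-by-name should unambiguously resolve "Q" to the user's input
print(Q.eval({"Q": np.eye(3)}))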
self.last_shape_cache = tuple(last_shape_cache)
self.lwork_cache = tuple(lwork_cache)
You shouldn't mutate an Op; it can be reused multiple times. If you want, mutate the tag of an Apply node, since that's unique in a graph, whereas an Op is not.
Yeah, I know this is dumb; I was hoping to push a change before you arrived with this comment. Mission failed.
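For what it's worth, a minimal, hypothetical sketch of the tag-based caching suggested above (the Op and attribute names are illustrative, not this PR's code); the point is just that node.tag is unique per Apply node, while the Op instance may be shared:

import numpy as np
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable


class TagCachedOp(Op):
    def make_node(self, x):
        x = as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        # node.tag is a per-node scratchpad, so state cached here cannot
        # leak between the different nodes that reuse this Op instance
        if getattr(node.tag, "shape_cache", None) != x.shape:
            node.tag.shape_cache = x.shape  # e.g. recompute lwork here
        output_storage[0][0] = np.asarray(x).copy()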
assert x.ndim == 2, "The input of qr function should be a matrix."
M, N = x.shape
Put this in make_node; the perform shouldn't really get anything that make_node didn't allow.
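For concreteness, a hedged sketch of that suggestion (simplified to a "reduced"-style QR; this is not the PR's code): the rank check lives in make_node, and the outputs are deliberately left unnamed:

from pytensor.graph.basic import Apply
from pytensor.tensor import as_tensor_variable, tensor


def make_node(self, x):  # method of the QR Op
    x = as_tensor_variable(x)
    if x.type.ndim != 2:
        raise ValueError(f"Input to QR must be a matrix, got {x.type.ndim} dimensions")
    M, N = x.type.shape  # static shapes, entries may be None
    K = min(M, N) if None not in (M, N) else None
    Q = tensor(shape=(M, K), dtype=x.type.dtype)
    R = tensor(shape=(K, N), dtype=x.type.dtype)
    return Apply(self, [x], [Q, R])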
if shapes_unknown or M_static >= N_static:
    # gradient expression when m >= n
    M = R @ _H(dR) - _H(dQ) @ Q
    K = dQ + Q @ _copyltu(M)
    A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

    if not shapes_unknown:
        return [A_bar_m_ge_n]

# We have to trigger both branches if shapes_unknown is True, so this is purposefully not an elif branch
if shapes_unknown or M_static < N_static:
    # gradient expression when m < n
    Y = A[:, m:]
    U = R[:, :m]
    dU, dV = dR[:, :m], dR[:, m:]
    dQ_Yt_dV = dQ + Y @ _H(dV)
    M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
    X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
    Y_bar = Q @ dV
    A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

    if not shapes_unknown:
        return [A_bar_m_lt_n]

return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
Instead of duplicating code logic, can you add an eager_ifelse? I did something like that in this WIP PR: 0dc2a1e
from typing import cast

from pytensor.tensor import switch
from pytensor.tensor.variable import TensorVariable


def _eager_switch(
    cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable
) -> TensorVariable:
    # Do not create a switch if cond is True/False
    # We need this because uint types cannot be negative and creating the lazy switch could upcast everything to float64
    # It also simplifies immediately the graph that's returned
    if isinstance(cond, bool):
        return a if cond else b
    return cast(TensorVariable, switch(cond, a, b))
This way the code is the same whether you can do it eagerly or not? Or do you actually manage to return something different in the branches if you know in advance?
How would I use the symbolic shapes as cond in this case? Is there a point during rewrites where shapes are evaluated if possible? That is to say, I did it this way because I thought I needed to check the static shape and the dynamic shape separately.
I see. Is it a problem if the ifelse only disappears during constant folding in compilation? If yes, I can bother with thinking of a solution.
No, that's fine. Is that what would happen with _eager_switch? If so, I don't see how.
To be clear though, there's no logic duplication here, only the extra returns inside each if block. The QR gradient always has two cases, depending on the shape of the input matrix.
There's some complexity because you're trying to use static shapes and prune branches of the ifelse eagerly. PyTensor will figure out easily enough that M > N in an ifelse if it has static shapes and prune the useless ifelse branch.
If there's a reason why you can't wait for PyTensor to do this (in the subtensor PR it just led to giant graph complexity), we can try to create an eager_ifelse that tries to see if M > N and only returns one branch if that's the case, so that the L_op code looks pretty much the same regardless of whether it can be removed or not.
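For reference, a minimal sketch of what such an eager_ifelse could look like (the name and the bool-condition convention mirror _eager_switch above and are assumptions, not an existing PyTensor helper):

from pytensor.ifelse import ifelse


def eager_ifelse(cond, then_branch, else_branch):
    # If the condition is already a plain Python bool (e.g. decided from
    # static shapes), return the surviving branch and build no ifelse at all
    if isinstance(cond, bool):
        return then_branch if cond else else_branch
    return ifelse(cond, then_branch, else_branch)

With something like this, the L_op could call eager_ifelse unconditionally, and the static-shape case would collapse to a single branch before the graph is even built.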
Description
This PR updates the QRFull Op, adding static shape checking, infer_shape, and destroy_map. It also optimizes the perform method for the C backend, and tries to improve the gradient graph by checking static shapes (to avoid an ifelse).
I renamed it to QR, because I don't know what was Full about the old one. I also moved it from the numpy implementation to scipy, which gives us all the usual benefits (inplace, etc). I also went ahead and unpacked the scipy wrapper and used the LAPACK functions directly. This will give us better error handling (that is to say, none -- it should eventually return a matrix of NaN on failure) and some performance boost by caching workspace requirements.
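As an illustration of the workspace caching mentioned above, here is a minimal sketch (not this PR's implementation; the cache layout and names are assumptions) of the standard LAPACK lwork query through SciPy's geqrf wrapper:

import numpy as np
from scipy.linalg import get_lapack_funcs

_lwork_cache: dict = {}


def qr_r(a):
    a = np.asarray(a)
    (geqrf,) = get_lapack_funcs(("geqrf",), (a,))
    key = (a.shape, a.dtype)
    lwork = _lwork_cache.get(key)
    if lwork is None:
        # workspace query: lwork=-1 makes LAPACK report the optimal size in work[0]
        _, _, work, _ = geqrf(a, lwork=-1)
        _lwork_cache[key] = lwork = int(work[0].real)
    qr_packed, tau, work, info = geqrf(a, lwork=lwork, overwrite_a=False)
    return np.triu(qr_packed[: min(a.shape)])  # the R factor; Q is stored implicitly in qr_packed and tau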
Still a WIP, because it breaks everything by moving QR from nlinalg to slinalg. I thought about using this as an opportunity to finally eliminate this distinction and go to a more logical organization (linalg/decomposition/qr.py), but then decided against it for now. Needs discussion.
Related Issue
infer_shape method to QRFull #1511

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1518.org.readthedocs.build/en/1518/