Refactor and update QR Op #1518


Draft
jessegrabowski wants to merge 4 commits into main from qr-shape-inference

Conversation


@jessegrabowski (Member) commented on Jul 2, 2025

Description

This PR updates the QRFull Op, adding static shape checking, infer_shape, and destroy_map. It also optimizes the perform method for the C backend, and tries to improve the gradient graph by checking static shapes (to avoid an ifelse).
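
For reference, a minimal sketch of what the shape inference could look like (mode names follow scipy's `qr`; this is illustrative only, not the PR's exact code, and assumes `pt` is `pytensor.tensor`):

```python
# Minimal sketch (not the PR's actual implementation) of infer_shape for a QR Op
# with scipy-style modes. A destroy_map such as {0: [0]} would additionally declare
# which output overwrites which input when the Op is allowed to run in place.
def infer_shape(self, fgraph, node, shapes):
    (M, N) = shapes[0]  # symbolic row/column counts of the input matrix
    match self.mode:
        case "full":      # Q: (M, M), R: (M, N)
            return [(M, M), (M, N)]
        case "economic":  # Q: (M, K), R: (K, N) with K = min(M, N)
            K = pt.minimum(M, N)
            return [(M, K), (K, N)]
        case _:
            raise NotImplementedError(f"infer_shape not sketched for mode={self.mode}")
```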

I renamed it to QR, because I don't know what was Full about the old one. I also moved it from the numpy implementation to scipy, which gives us all the usual benefits (inplace, etc). I also went ahead and unpacked the scipy wrapper and used the LAPACK functions directly. This will give us better error handling (that is to say, none -- it should eventually return a matrix of NaN on failure) and some performance boost by caching workspace requirements.
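
For context, a rough sketch of that pattern (illustrative names and simplifications, assuming a real-valued input with M >= N and economic mode; not the PR's exact code):

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

def qr_economic(x, overwrite_a=False):
    # x: real (M, N) ndarray with M >= N
    geqrf, orgqr = get_lapack_funcs(("geqrf", "orgqr"), (x,))

    # Workspace query: lwork=-1 makes LAPACK report the optimal size in work[0].
    # The Op can cache this value and skip the query while the input shape is unchanged.
    _, _, work, _ = geqrf(x, lwork=-1, overwrite_a=False)
    lwork = int(work[0].real)

    qr_raw, tau, work, info = geqrf(x, lwork=lwork, overwrite_a=overwrite_a)
    R = np.triu(qr_raw[: x.shape[1], :])  # (N, N) upper-triangular factor
    # (no error handling here; a real implementation would at least check `info`)

    # Build Q explicitly from the elementary reflectors stored in qr_raw
    _, work, _ = orgqr(qr_raw, tau, lwork=-1)
    Q, _, info = orgqr(qr_raw, tau, lwork=int(work[0].real))
    return Q, R
```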

Still a WIP, because it breaks everything by moving QR from nlinalg to slinalg. I thought about using this as an opportunity to finally eliminate this distinction and go to a more logical organization (linalg/decomposition/qr.py), but then decided against it for now. Needs discussion.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1518.org.readthedocs.build/en/1518/

@jessegrabowski added the enhancement (New feature or request), maintenance, and linalg (Linear algebra) labels on Jul 2, 2025
@jessegrabowski force-pushed the qr-shape-inference branch 2 times, most recently from c2e08a2 to 5bc044c on July 3, 2025 00:48
match self.mode:
    case "full":
        outputs = [
            tensor("Q", shape=(M, M), dtype=out_dtype),
Member

It's better not to give names to intermediate variables; this would shadow a user's Q if they wanted to do eval({"Q": something}) or find a variable by name.
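
In other words, something like the following (a sketch of the suggestion, not the final code):

```python
# Leave the intermediate outputs unnamed so they can't collide with a user
# variable named "Q" (e.g. in graph.eval({"Q": value})).
outputs = [
    tensor(shape=(M, M), dtype=out_dtype),  # Q
    tensor(shape=(M, N), dtype=out_dtype),  # R
]
```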

Comment on lines +1866 to +1867
self.last_shape_cache = tuple(last_shape_cache)
self.lwork_cache = tuple(lwork_cache)
Member

You shouldn't mutate an Op; it can be reused multiple times. If you want, mutate the tag of an apply node, since that's unique in a graph, whereas an Op is not.
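
A hypothetical sketch of that suggestion, caching the workspace size on the apply node's tag rather than on the Op (`_query_lwork` is a made-up helper standing in for the lwork=-1 workspace query):

```python
def perform(self, node, inputs, output_storage):
    (x,) = inputs
    # node.tag is unique to this apply node, unlike the Op, which may be shared.
    cached = getattr(node.tag, "lwork_cache", None)
    if cached is None or cached[0] != x.shape:
        lwork = _query_lwork(x)  # workspace query, run once per input shape
        node.tag.lwork_cache = (x.shape, lwork)
    else:
        _, lwork = cached
    ...  # call geqrf/orgqr with the cached lwork and fill output_storage
```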

Member Author

Yeah, I know this is dumb, and I was hoping to push a change before you arrived with this comment. Mission failed.

Comment on lines +1874 to +1875
assert x.ndim == 2, "The input of qr function should be a matrix."
M, N = x.shape
Member

Put this in make_node; perform shouldn't really get anything that make_node didn't allow.
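
A sketch of the suggested change (not the PR's exact code), validating the input in make_node so perform can assume it only ever sees a matrix:

```python
def make_node(self, x):
    x = pt.as_tensor_variable(x)
    if x.type.ndim != 2:
        raise TypeError(f"Input to QR must be a matrix, got {x.type.ndim} dimensions")
    ...  # build the output variables based on self.mode and x.type.shape
```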

Comment on lines +2015 to +2039
if shapes_unknown or M_static >= N_static:
    # gradient expression when m >= n
    M = R @ _H(dR) - _H(dQ) @ Q
    K = dQ + Q @ _copyltu(M)
    A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

    if not shapes_unknown:
        return [A_bar_m_ge_n]

# We have to trigger both branches if shapes_unknown is True, so this is purposefully not an elif branch
if shapes_unknown or M_static < N_static:
    # gradient expression when m < n
    Y = A[:, m:]
    U = R[:, :m]
    dU, dV = dR[:, :m], dR[:, m:]
    dQ_Yt_dV = dQ + Y @ _H(dV)
    M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
    X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
    Y_bar = Q @ dV
    A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

    if not shapes_unknown:
        return [A_bar_m_lt_n]

return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
Member

Instead of duplicating the code logic, can you add an eager_ifelse? I did something like that in this WIP PR: 0dc2a1e

def _eager_switch(
    cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable
) -> TensorVariable:
    # Do not create a switch if cond is True/False
    # We need this because uint types cannot be negative and creating the lazy switch could upcast everything to float64
    # It also simplifies immediately the graph that's returned
    if isinstance(cond, bool):
        return a if cond else b
    return cast(TensorVariable, switch(cond, a, b))

This way the code is the same whether you can evaluate the condition eagerly or not? Or do you actually manage to return something different in the branches if you know the shapes in advance?

Member Author (@jessegrabowski, Jul 3, 2025)

How would I use the symbolic shapes as cond in this case? Is there a point during rewrites where shapes are evaluated when possible?

That is to say, I did it this way because I thought I needed to check the static shape and the dynamic shape separately.

Member

I see. Is it a problem if the ifelse only disappears during constant folding in compilation?

If yes, I can bother with thinking of a solution.

Member Author

No, that's fine. Is that what would happen with _eager_switch? If so, I don't see how.

Member Author

To be clear, though, there's no logic duplication here, only the extra returns inside each if block. The QR gradient always has two cases, depending on the shape of the input matrix.

Member

There's some complexity because you're trying to use static shapes and prune branches of the ifelse eagerly. PyTensor will figure out M > N in an ifelse easily enough if it has static shapes, and prune the useless ifelse branch.

If there's a reason why you can't wait for PyTensor to do this (in the subtensor PR it just led to giant graph complexity), we can try to create an eager_ifelse that checks whether M > N and only returns one branch if that's the case, so that the L_op code looks pretty much the same regardless of whether the ifelse can be removed or not.
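
For illustration, one possible shape of such a helper (hypothetical; it mirrors _eager_switch above and uses the same ptm.ge condition as the gradient code):

```python
import pytensor.tensor.math as ptm
from pytensor.ifelse import ifelse

def _eager_ifelse(m_static, n_static, m, n, a, b):
    # Pick a branch at graph-construction time when both static shapes are known;
    # the unused branch graph is simply never referenced or compiled.
    if m_static is not None and n_static is not None:
        return a if m_static >= n_static else b
    # Otherwise build the lazy ifelse on the symbolic shapes; constant folding can
    # still remove it later if the shapes become known.
    return ifelse(ptm.ge(m, n), a, b)

# In the QR L_op this would replace the early returns, e.g.:
# return [_eager_ifelse(M_static, N_static, m, n, A_bar_m_ge_n, A_bar_m_lt_n)]
```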

Labels: enhancement (New feature or request), linalg (Linear algebra), maintenance
Projects: None yet

Development

Successfully merging this pull request may close these issues:

Add infer_shape method to QRFull

2 participants