RFC: add support for LU factorization in the linalg extension #627

@ogrisel

Description

It seems that many libraries that are candidates to implement the Array API namespace already implement the LU factorization (with variations in API and with the notable exception of numpy).

However, LU is not part of the list of linear algebra operations in the current version of the spec.

Are there any plans to consider it for inclusion?

Activity

rgommers commented on May 8, 2023

Member

Thanks for asking @ogrisel. I had a look through the initial issues that considered the various linalg APIs, and LU decomposition was not considered at all there. The main reason, I think, is that the overview started from what is present in NumPy and then looked at matching APIs in other libraries.

I think it's an uphill battle for now. It would require adding it to numpy.linalg, moving it to the right namespace in other libraries with NumPy-matching APIs (e.g., https://docs.cupy.dev/en/stable/reference/scipy_linalg.html#decompositions is in the wrong place), and then aligning on APIs with PyTorch & co. as well. Finally, there's lu but also lu_solve and lu_factor: would it be just one of those, or two, or all three?

It seems to me that LU decomposition is important enough that it's worth working on. So we could figure out what the optimal API for it would be, and then add it to array-api-compat so it can be used in scikit-learn and SciPy. That can be done on pretty short notice, I think. Getting from there to actually standardizing it would take quite a long time, I suspect (but nothing is really blocked on not having that done).
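For reference, the three SciPy functions mentioned above return different things: lu returns the explicit factors P, L, U, while lu_factor returns L and U packed into a single matrix plus a 1D pivot array, which lu_solve then consumes. A minimal sketch using scipy.linalg (the example matrix is arbitrary):

```python
import numpy as np
from scipy.linalg import lu, lu_factor, lu_solve

A = np.array([[2.0, 5.0, 8.0],
              [5.0, 2.0, 2.0],
              [7.0, 5.0, 6.0]])
b = np.array([1.0, 2.0, 3.0])

# lu: explicit factors with A == P @ L @ U
P, L, U = lu(A)

# lu_factor: unit lower-triangular L (stored below the diagonal) and U
# packed into one matrix, plus a 1D array of pivot indices
lu_packed, piv = lu_factor(A)

# lu_solve: reuse the packed factorization to solve A @ x == b
x = lu_solve((lu_packed, piv), b)
```

Only lu maps directly onto the "return the factors" style; lu_factor/lu_solve expose a LAPACK-shaped intermediate representation.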

rgommers commented on May 17, 2023

Member

The signatures to consider:

  • SciPy/cupyx.scipy/jax.scipy: lu(a, permute_l=False, overwrite_a=False, check_finite=True)
  • PyTorch: torch.linalg.lu(A, *, pivot=True, out=None)

The overwrite_a, check_finite and out keywords should all be out of scope for the standard.

The permute_l/pivot keywords do seem relevant to include. They both control the return values, but in different ways. SciPy's lu returns 3 arrays if permute_l is False and 2 arrays if it is True. That breaks a key design rule for the array API standard (no polymorphic APIs), so we can't do that. PyTorch's pivot keyword is okay: it always returns a named tuple (P, L, U), and leaves P as an empty array for the non-default pivot=False.
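To make the polymorphism concrete: with SciPy, the number of returned arrays depends on the value of permute_l, which is what a standardized signature would need to avoid. A small illustration with scipy.linalg.lu:

```python
import numpy as np
from scipy.linalg import lu

A = np.array([[1.0, 2.0], [3.0, 4.0]])

# permute_l=False (default): three arrays, A == P @ L @ U
out_default = lu(A)

# permute_l=True: two arrays, the first one being the product P @ L
out_permuted = lu(A, permute_l=True)

print(len(out_default), len(out_permuted))  # 3 2
```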

PyTorch defaults to partial pivoting, and the pivot keyword allows disabling pivoting. An LU decomposition with full pivoting also exists mathematically, but it does not seem to be implemented in these libraries. JAX also has jax.lax.linalg.lu, which only does partial pivoting.

So it seems that lu(x, /) -> namedtuple(array, array, array), defaulting to partial pivoting, is the minimum needed; the question is whether the other pivoting mode(s) are needed.
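A minimal pure-NumPy sketch of what such an lu(x, /) could compute, assuming square input and the SciPy/PyTorch convention A == P @ L @ U. The names LUResult and lu here are illustrative only, not an existing API, and the loop-based elimination favors clarity over speed:

```python
from typing import NamedTuple

import numpy as np


class LUResult(NamedTuple):
    # Hypothetical result type mirroring PyTorch's named tuple (P, L, U)
    P: np.ndarray
    L: np.ndarray
    U: np.ndarray


def lu(x):
    """LU with partial pivoting for a square matrix, such that x == P @ L @ U."""
    a = np.array(x, dtype=np.float64)  # working copy, overwritten with U
    n = a.shape[0]
    perm = np.arange(n)  # perm[i] = original row now stored in row i
    L = np.eye(n)
    for k in range(n - 1):
        # Partial pivoting: bring the largest |entry| of column k into row k
        p = k + int(np.argmax(np.abs(a[k:, k])))
        if p != k:
            a[[k, p], :] = a[[p, k], :]
            L[[k, p], :k] = L[[p, k], :k]  # swap the multipliers computed so far
            perm[[k, p]] = perm[[p, k]]
        if a[k, k] == 0.0:
            continue  # column is already zero below the diagonal
        # Eliminate the entries below the pivot
        L[k + 1:, k] = a[k + 1:, k] / a[k, k]
        a[k + 1:, k:] -= np.outer(L[k + 1:, k], a[k, k:])
    U = np.triu(a)
    # Elimination produced x[perm] == L @ U, hence x == P @ L @ U with:
    P = np.eye(n)[perm].T
    return LUResult(P, L, U)
```

For example, lu([[1.0, 2.0], [3.0, 4.0]]).P is the row swap [[0, 1], [1, 0]], because the larger first-column entry (3) is moved to the top.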

rgommers commented on May 17, 2023

Member

dask.array.linalg.lu has no keywords at all, and no info in the docstring about what is implemented. From the tests it's clear that it matches the SciPy default (permute_l=False).

rgommers commented on May 17, 2023

Member

For PyTorch, the non-default flavor is only implemented on GPU:

>>> A = torch.randn(3, 2)
... P, L, U = torch.linalg.lu(A)
>>> A = torch.randn(3, 2)
... P, L, U = torch.linalg.lu(A, pivot=False)
Traceback (most recent call last):
  Cell In[6], line 2
    P, L, U = torch.linalg.lu(A, pivot=False)
RuntimeError: linalg.lu_factor: LU without pivoting is not implemented on the CPU

Its docstring also notes: The LU decomposition without pivoting may not exist if any of the principal minors of A is singular.

tl;dr maybe the best way to go is to only implement partial pivoting?
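The docstring caveat can be seen with a plain Doolittle elimination without row exchanges (lu_nopivot is a hypothetical helper written for illustration, not PyTorch's implementation): it breaks down as soon as a leading principal minor is singular, even when the matrix itself is invertible.

```python
import numpy as np


def lu_nopivot(a):
    """Hypothetical Doolittle LU without row exchanges: a == L @ U."""
    u = np.array(a, dtype=np.float64)  # working copy, overwritten with U
    n = u.shape[0]
    L = np.eye(n)
    for k in range(n - 1):
        if u[k, k] == 0.0:
            # det of the leading (k+1)x(k+1) block is u[0, 0] * ... * u[k, k] == 0
            raise ZeroDivisionError(
                f"zero pivot at step {k}: leading principal minor of order {k + 1} is singular"
            )
        L[k + 1:, k] = u[k + 1:, k] / u[k, k]
        u[k + 1:, k:] -= np.outer(L[k + 1:, k], u[k, k:])
    return L, np.triu(u)


L, U = lu_nopivot([[4.0, 3.0], [6.0, 3.0]])  # works: leading minor 4 != 0
try:
    lu_nopivot([[0.0, 1.0], [1.0, 0.0]])  # invertible, but a[0, 0] == 0
except ZeroDivisionError as exc:
    print(exc)
```

The second matrix is just a row permutation of the identity, so any partial-pivoting LU handles it trivially.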

ogrisel commented on May 17, 2023

Author

Maybe we can start with a function with no keyword arguments that always returns P, L, U (that is, only implement SciPy's permute_l=False and PyTorch's pivot=True), and it will be up to the consumer to compute the P @ L product.

On the other hand, I think it would be good to have an option to compute the P @ L product automatically and avoid allocating P. Should the array API expose two functions? xp.linalg.lu, which outputs a 3-tuple (P, L, U), and a second function, xp.linalg.permuted_lu, which precomputes the P @ L product and always outputs a 2-tuple (P @ L, U)?

Its docstring also notes: The LU decomposition without pivoting may not exist if any of the principal minors of A is singular.

Also note, from PyTorch's doc:

The LU decomposition is almost never unique, as often there are different permutation matrices that can yield different LU decompositions. As such, different platforms, like SciPy, or inputs on different devices, may produce different valid decompositions.

Such a disclaimer should probably be mentioned in the Array API spec.
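As a small worked example of this non-uniqueness (factors computed by hand for illustration): the matrix A below has one LU decomposition without pivoting (P = I) and another with the row swap that partial pivoting would choose. Both reconstruct A exactly.

```python
import numpy as np

A = np.array([[4.0, 3.0],
              [6.0, 3.0]])

# Decomposition 1: no row exchange (P = I)
P1 = np.eye(2)
L1 = np.array([[1.0, 0.0], [1.5, 1.0]])
U1 = np.array([[4.0, 3.0], [0.0, -1.5]])

# Decomposition 2: rows swapped so that the larger pivot (6) comes first
P2 = np.array([[0.0, 1.0], [1.0, 0.0]])
L2 = np.array([[1.0, 0.0], [2.0 / 3.0, 1.0]])
U2 = np.array([[6.0, 3.0], [0.0, 1.0]])

print(np.allclose(P1 @ L1 @ U1, A), np.allclose(P2 @ L2 @ U2, A))  # True True
```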

ogrisel commented on May 17, 2023

Author

Note that scipy.linalg.lu calls:

flu, = get_flinalg_funcs(('lu',), (a1,))
p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)

So flu is not polymorphic internally: it always returns p, l, u and info, and p is just a 1x1 array with a single 0 value when permute_l is True.

rgommers commented on May 18, 2023

Member

@ogrisel I opened gh-630 for the default (partial pivoting) case that seems supportable by all libraries.

On the other hand, I think it would be good to have an option to do the PL product automatically and avoid allocating P. Should the array API expose two functions? xp.linalg.lu that outputs a 3-tuple (P, L, U) and a second function xp.linalg.permuted_lu that precomputes the PL product and always outputs a 2-tuple (P @ L, U)?

Perhaps. The alternative of having an empty P like PyTorch does may work, but it's not ideal. JAX would have to preallocate a full-size array in case a keyword is used and it's not literal True/False.

Given that this use case seems more niche, that it's not supported by Dask or by PyTorch on CPU, and that you don't need it now in scikit-learn, waiting for a stronger need seems like the way to go here, though.

ogrisel commented on May 19, 2023

Author

We do use the "permute_l=True" case in scikit-learn.

ogrisel commented on May 19, 2023

Author

It would be easy to provide a fallback implementation that uses an extra temporary allocation plus a matrix-matrix product for libraries that do not natively support SciPy's permute_l=True.

But it's not clear whether PyTorch's pivot=False is equivalent to SciPy's permute_l=True or does something different.
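A minimal sketch of the fallback described above, assuming only some lu_func that returns (P, L, U) with A == P @ L @ U; the helper name lu_permute_l is hypothetical, and SciPy is used both as the backend and as the reference for comparison:

```python
import numpy as np
import scipy.linalg


def lu_permute_l(a, lu_func):
    """Hypothetical fallback emulating SciPy's permute_l=True.

    `lu_func` is any function returning (P, L, U) with a == P @ L @ U.
    """
    P, L, U = lu_func(a)
    # Extra temporary allocation (P) plus one matrix-matrix product
    return P @ L, U


A = np.array([[1.0, 2.0, 0.0],
              [3.0, 4.0, 1.0],
              [0.0, 5.0, 6.0]])
PL, U = lu_permute_l(A, scipy.linalg.lu)
PL_ref, U_ref = scipy.linalg.lu(A, permute_l=True)
```

Running it also shows that the first output is generally not lower triangular for this A: permute_l=True still pivots and simply folds P into L.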

[35 remaining items not shown]

ogrisel commented on May 25, 2023

Author

Thanks for the explicit summary.

p_indices does not change the shape of the output.

I am not sure what you mean by shape. I would expect the following:

>>> A, p_indices = xp.linalg.lu_factor(X, p_indices=True)
>>> p_indices.dtype.kind
'i'
>>> p_indices.ndim
1

while:

>>> A, P = xp.linalg.lu_factor(X, p_indices=False)
>>> P.dtype.kind
'f'
>>> P.ndim
2

This is what I meant by "soft polymorphic outputs".

lezcano commented on May 25, 2023

Contributor

When p_indices=False, xp.linalg.lu_factor would return the same thing as scipy.linalg.lu_factor. In particular, it returns a 1D array of pivot indices that can be fed into lu_solve.

If you want to recover the full 2D matrix P, you would do as discussed in the penultimate point in #627 (comment):

[...] one may compute the matrix P explicitly by doing P = xp.eye(n)[perm]

where perm = xp.linalg.lu_factor(X, p_indices=True)[1]
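The idea that a 1D permutation vector carries the same information as P can be checked with plain NumPy: for any index vector perm (an arbitrary example below), np.eye(n)[perm] is the corresponding permutation matrix, and multiplying by it is the same as indexing rows, which avoids allocating the n-by-n matrix at all.

```python
import numpy as np

n = 4
perm = np.array([2, 0, 3, 1])                    # arbitrary example permutation
L = np.tril(np.arange(1.0, 17.0).reshape(n, n))  # any matrix works here

P = np.eye(n)[perm]  # materialize the permutation matrix

# Left-multiplying by P selects the rows of L in the order given by perm
print(np.allclose(P @ L, L[perm]))  # True
```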

ogrisel commented on May 30, 2023

Author

I now realize that I misunderstood how the piv arrays returned by LAPACK actually work. The example in the docstring of scipy.linalg.lu_factor seems to be broken.

I find exposing this low-level, 1-based pivot information quite confusing anyway. I have the feeling it should not be part of the standard API, or at least not with the default options.

ilayn commented on May 30, 2023

I can fix the docs if something is broken. What kind of issue do you have with it?

And about exposing it: yes, that's what I meant in the last part of #627 (comment) about the "mistake". Note that they are pivot "swaps" on A, not permutations on L.
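To illustrate the distinction: scipy.linalg.lu_factor returns piv as a sequence of 0-based row swaps (row i was interchanged with row piv[i], applied in order), not as a permutation. Replaying the swaps gives a permutation vector perm with A[perm] == L @ U. The helper name piv_to_perm is ours, written for illustration only:

```python
import numpy as np
from scipy.linalg import lu_factor


def piv_to_perm(piv):
    """Replay LAPACK-style row swaps (0-based) into a permutation vector."""
    perm = np.arange(len(piv))
    for i, j in enumerate(piv):
        perm[[i, j]] = perm[[j, i]]  # row i was swapped with row j
    return perm


A = np.array([[2.0, 5.0, 8.0, 7.0],
              [5.0, 2.0, 2.0, 8.0],
              [7.0, 5.0, 6.0, 6.0],
              [5.0, 4.0, 4.0, 8.0]])
lu_packed, piv = lu_factor(A)
L = np.tril(lu_packed, k=-1) + np.eye(4)  # unit lower-triangular factor
U = np.triu(lu_packed)
perm = piv_to_perm(piv)
print(np.allclose(A[perm], L @ U))  # True

# The P of the A == P @ L @ U convention is the transpose of eye(4)[perm]
P = np.eye(4)[perm].T
```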

lezcano commented on May 30, 2023

Contributor

Which API would you guys then propose that can be implemented with the current tools (BLAS/cuBLAS/cuSolver/MAGMA) and that allows solving linear systems fast, possibly with a common LHS?

ilayn commented on May 30, 2023

My vote goes to a Factorize class like the ones Julia or Rust implement. The name can be something else, but a polyalgorithm is a must in this day and age for these operations. In my opinion, lu_factor, chol_solve etc. belong to the Fortran age.

If you have an array A and pass it to Factorize (or whatever the name is), then it stores a factorization of A chosen according to the properties of A. These properties can be discovered at load time, as in Factorize(A), or enforced, as in Factorize(A, assume_a='sym'). Then you can call solve on it and it will use whatever factorization is stored internally. You can repeat solve with different RHS, or perform other operations like .update(), .downdate(), det, rcond() in a lazy-loading way. It can also hold multiple factorizations.

If you need an explicit factorization, you can still call lu, qr and so on. This approach is also agnostic to sparse vs. dense input, and it is pretty much what MATLAB offers as its most famous convenience. You don't focus on the operation itself; you focus on the data you provide, and the algorithms handle the complication.

Otherwise we will repeat the same Fortran API for yet another two decades. That's pretty much what I wanted to do in scipy/scipy#12824, but then I lost my way due to other things. It tries to address the discovery part for solve. The idea is that linalg APIs have moved on, and it is quite inconvenient to fight with the LAPACK API just to get a single operation with internal arrays, convert things from the internal LAPACK representation into forms usable in Scientific Python, etc.

We are now trying to reduce our Fortran 77 codebase in SciPy, which is taking quite some time, and after that I am hoping to get back to this. Anyway, that's a rough sketch, but it's roughly what I wish to have in the end.
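A rough Python sketch of the Factorize idea described above, assuming SciPy as the backend. The class name Factorize, its assume_a parameter and the solve() method follow the description, but this code is only an illustration: it handles just two cases (Cholesky for symmetric positive definite input, LU otherwise), whereas a real polyalgorithm would cover many more structures and operations (update, det, rcond, ...).

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve, lu_factor, lu_solve


class Factorize:
    """Hypothetical factorize-once, solve-many object (illustration only)."""

    def __init__(self, a, assume_a=None):
        a = np.asarray(a, dtype=np.float64)
        if assume_a == "pos":
            # Structure enforced by the caller
            self.kind, self._factors = "pos", cho_factor(a)
        elif assume_a is None and np.allclose(a, a.T):
            # Structure discovered at load time: try Cholesky first
            try:
                self.kind, self._factors = "pos", cho_factor(a)
            except np.linalg.LinAlgError:
                self.kind, self._factors = "gen", lu_factor(a)
        else:
            # General case: LU with partial pivoting
            self.kind, self._factors = "gen", lu_factor(a)

    def solve(self, b):
        # Reuses the stored factorization for every new right-hand side
        if self.kind == "pos":
            return cho_solve(self._factors, b)
        return lu_solve(self._factors, b)
```

Usage: f = Factorize(A) factorizes once, after which f.solve(b1), f.solve(b2), ... reuse the stored factors, and the caller never sees pivot arrays or packed LAPACK layouts.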

ogrisel commented on May 31, 2023

Author

I can fix the docs if something is broken. What kind of issue do you have with it?

I thought the example in the docstring was broken (I thought I had copied and pasted the snippets and that they failed), but that's actually not the case. Still, I think it would help to be more explicit: scipy/scipy#18594.

For some reason I thought I had read somewhere that piv would be 1-based (as in LAPACK), but it's actually 0-based in SciPy.

changed the title "LU factorization in the linalg extension" to "RFC: add support for LU factorization in the linalg extension" on Apr 4, 2024
added the label RFC (Request for comments. Feature requests and proposed changes.) on Apr 4, 2024
ogrisel commented on May 7, 2024

Author

I find the proposal in #627 (comment) interesting. Would it be possible to implement it on top of the existing building blocks available in numpy/scipy, torch and co?

If it can be done (e.g., in a proof-of-concept repo), it could help move the spec discussion forward, since the proposal would be backed by working code.


Metadata

Assignees: No one assigned

Labels: Needs Discussion (Needs further discussion.), RFC (Request for comments. Feature requests and proposed changes.), topic: Linear Algebra (Linear algebra.)

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Participants: @ogrisel, @rgommers, @nikitaved, @ilayn, @kgryte

RFC: add support for LU factorization in the linalg extension · Issue #627 · data-apis/array-api