Skip to content

Commit f72d7e5

Browse files
Add JAX dispatch for CholeskySolve Op (#1491)
* Add jax dispatch for CholeskySolve * Better typehints on user-facing `cho_solve` * Rename test
1 parent d3bbc20 commit f72d7e5

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
LU,
88
BlockDiagonal,
99
Cholesky,
10+
CholeskySolve,
1011
Eigvalsh,
1112
LUFactor,
1213
PivotToPermutations,
@@ -153,3 +154,17 @@ def lu_factor(a):
153154
)
154155

155156
return lu_factor
157+
158+
159+
@jax_funcify.register(CholeskySolve)
160+
def jax_funcify_ChoSolve(op, **kwargs):
161+
lower = op.lower
162+
check_finite = op.check_finite
163+
overwrite_b = op.overwrite_b
164+
165+
def cho_solve(c, b):
166+
return jax.scipy.linalg.cho_solve(
167+
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b
168+
)
169+
170+
return cho_solve

pytensor/tensor/slinalg.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,14 +376,20 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
376376
return self
377377

378378

379-
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
379+
def cho_solve(
380+
c_and_lower: tuple[TensorLike, bool],
381+
b: TensorLike,
382+
*,
383+
check_finite: bool = True,
384+
b_ndim: int | None = None,
385+
):
380386
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
381387
382388
Parameters
383389
----------
384-
(c, lower) : tuple, (array, bool)
390+
c_and_lower : tuple of (TensorLike, bool)
385391
Cholesky factorization of a, as given by cho_factor
386-
b : array
392+
b : TensorLike
387393
Right-hand side
388394
check_finite : bool, optional
389395
Whether to check that the input matrices contain only finite numbers.

tests/link/jax/test_slinalg.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape):
333333
out = pt_slinalg.lu_solve(lu_and_pivots, b)
334334

335335
compare_jax_and_py([A, b], [out], [A_val, b_val])
336+
337+
338+
@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])
339+
def test_jax_cho_solve(b_shape, lower):
340+
rng = np.random.default_rng(utt.fetch_seed())
341+
L_val = rng.normal(size=(5, 5)).astype(config.floatX)
342+
A_val = (L_val @ L_val.T).astype(config.floatX)
343+
344+
b_val = rng.normal(size=b_shape).astype(config.floatX)
345+
346+
A = pt.tensor(name="A", shape=(5, 5))
347+
b = pt.tensor(name="b", shape=b_shape)
348+
c = pt_slinalg.cholesky(A, lower=lower)
349+
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
350+
351+
compare_jax_and_py([A, b], [out], [A_val, b_val])

0 commit comments

Comments
 (0)