Skip to content

Commit 8f10a44

Browse files
committed
.progress
1 parent 56c2251 commit 8f10a44

File tree

3 files changed

+118
-48
lines changed

3 files changed

+118
-48
lines changed
Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
from pytensor.graph import node_rewriter
2-
from pytensor.tensor._linalg.solve.tridiagonal import split_solve_tridiagonal, decompose_of_solve_tridiagonal
1+
from pytensor.graph.rewriting.basic import node_rewriter
2+
from pytensor.tensor._linalg.solve.tridiagonal import (
3+
decompose_tridiagonals,
4+
solve_decomposed_tridiagonal,
5+
split_solve_tridiagonal,
6+
)
7+
from pytensor.tensor.basic import diagonal
38
from pytensor.tensor.blockwise import Blockwise
49
from pytensor.tensor.elemwise import DimShuffle
510
from pytensor.tensor.rewriting.basic import register_specialize
@@ -10,7 +15,9 @@
1015
@register_specialize
1116
@node_rewriter(tracks=[Blockwise])
1217
def batched_solve_decomposition(fgraph, node):
13-
if not(isinstance(node.op.core_op, Solve) and node.op.core_op.assume_a == "tridiagonal"):
18+
if not (
19+
isinstance(node.op.core_op, Solve) and node.op.core_op.assume_a == "tridiagonal"
20+
):
1421
return
1522

1623
a, b = node.inputs
@@ -20,8 +27,11 @@ def batched_solve_decomposition(fgraph, node):
2027
# Check if a is broadcasted in computing the output
2128
if not any(
2229
a_bcast and not b_bcast
23-
for a_bcast, b_bcast
24-
in zip(a.type.broadcastable[:batch_ndim], b.type.broadcastable[:batch_ndim], strict=True)
30+
for a_bcast, b_bcast in zip(
31+
a.type.broadcastable[:batch_ndim],
32+
b.type.broadcastable[:batch_ndim],
33+
strict=True,
34+
)
2535
):
2636
return
2737

@@ -32,7 +42,6 @@ def batched_solve_decomposition(fgraph, node):
3242
@register_specialize
3343
@node_rewriter([Blockwise])
3444
def reuse_lu_decomp_multiple_solves(fgraph, node):
35-
3645
if not isinstance(node.op.core_op, Solve):
3746
return None
3847

@@ -43,32 +52,61 @@ def reuse_lu_decomp_multiple_solves(fgraph, node):
4352
return None
4453

4554
def find_solve_clients(var):
    """Collect the Blockwise(Solve) nodes that use ``var`` as coefficient matrix.

    A client is collected when ``var`` is its first input (index 0), its op
    is a ``Blockwise`` wrapping a ``Solve``, and that ``Solve`` has the same
    ``assume_a`` as the node being rewritten. Clients that are left
    ``expand_dims`` DimShuffles are followed recursively, so Solves that
    consume a left-expanded view of ``var`` are collected as well.

    ``fgraph`` and ``assume_a`` are read from the enclosing rewrite's scope.

    Returns
    -------
    list
        The matching client nodes (possibly empty).
    """
    clients = []
    for cl, idx in fgraph.clients[var]:
        if (
            idx == 0
            and isinstance(cl.op, Blockwise)
            and isinstance(cl.op.core_op, Solve)
            # Only Solves with a matching `assume_a` may share the
            # decomposition; without this check a Solve with a different
            # structure assumption on the same matrix would be replaced by
            # a solve that only reads the three central diagonals.
            and cl.op.core_op.assume_a == assume_a
        ):
            clients.append(cl)
        elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
            # If it's a left expand_dims, recurse on the output
            clients.extend(find_solve_clients(cl.outputs[0]))
    return clients
5567

5668
[A, _] = node.inputs
5769
if A.owner is not None and isinstance(A.owner.op, DimShuffle):
58-
# FIXME: Don't consider if dimshuffle mixes batch and core dims
70+
# If this DimShuffle is more than left expand_dims / matrix transpose
71+
# We won't find "clients" again and will exit the rewrite
5972
[A] = A.owner.inputs
6073

61-
# Find Solve using A
62-
A_solve_clients = [(client, False) for client in find_solve_clients(A)]
74+
# Find Solve using A (or left expand_dims of A)
75+
A_direct_solve_clients = find_solve_clients(A)
6376

6477
# Find Solves using A.T
78+
A_transpose_solve_clients = []
6579
for cl, _ in fgraph.clients[A]:
6680
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
6781
A_T = cl.out
68-
A_solve_clients.extend((client, True) for client in find_solve_clients(A_T))
82+
A_transpose_solve_clients.extend(find_solve_clients(A_T))
83+
84+
if (len(A_direct_solve_clients) + len(A_transpose_solve_clients)) <= 1:
85+
# With at most one Solve client there is no decomposition to share, so we leave the graph as is.
86+
# NOTE: Making the diagonal extraction symbolic could still be useful in that case,
87+
# e.g., when the diagonals are themselves set in the graph (which would make the extraction unnecessary).
88+
return None
89+
90+
dl, d, du = (diagonal(A, offset=o, axis1=-2, axis2=-1) for o in (-1, 0, 1))
6991

70-
A_decomp = decompose_of_solve_tridiagonal(A)
7192
replacements = {}
72-
for client, transpose in A_solve_clients:
73-
_, b = client.inputs
74-
return replacements
93+
if A_direct_solve_clients:
94+
A_direct_decomp = decompose_tridiagonals(dl, d, du)
95+
for client in A_direct_solve_clients:
96+
_, b = client.inputs
97+
b_ndim = client.op.core_op.b_ndim
98+
replacements[client.outputs[0]] = solve_decomposed_tridiagonal(
99+
A_direct_decomp, b, b_ndim=b_ndim
100+
)
101+
102+
if A_transpose_solve_clients:
103+
# We just need to swap the off-diagonals
104+
A_transpose_decomp = decompose_tridiagonals(du, d, dl)
105+
for client in A_transpose_solve_clients:
106+
_, b = client.inputs
107+
b_ndim = client.op.core_op.b_ndim
108+
replacements[client.outputs[0]] = solve_decomposed_tridiagonal(
109+
A_transpose_decomp, b, b_ndim=b_ndim
110+
)
111+
112+
return replacements

pytensor/tensor/_linalg/solve/tridiagonal.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1-
import scipy
21
import numpy as np
2+
import scipy
33
from scipy.linalg import get_lapack_funcs
44

5-
from pytensor.graph import Op, Apply
5+
from pytensor.graph import Apply, Op
66
from pytensor.tensor.basic import as_tensor, diagonal
7-
from pytensor.tensor.type import tensor, vector
87
from pytensor.tensor.blockwise import Blockwise
98
from pytensor.tensor.slinalg import Solve
9+
from pytensor.tensor.type import tensor, vector
1010

1111

1212
class LUFactorTridiagonal(Op):
1313
"""Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
14-
__props__ = ("overwrite_dl", "overwrite_d", "overwrite_du",)
14+
15+
__props__ = (
16+
"overwrite_dl",
17+
"overwrite_d",
18+
"overwrite_du",
19+
)
1520
gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
1621

1722
def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
@@ -29,11 +34,8 @@ def make_node(self, dl, d, du):
2934
ndl, nd, ndu = (inp.type.shape[-1] for inp in (dl, d, du))
3035
n = (
3136
ndl + 1
32-
if ndl is not None else (
33-
nd if nd is not None else (
34-
ndu + 1 if ndu is not None else None
35-
)
36-
)
37+
if ndl is not None
38+
else (nd if nd is not None else (ndu + 1 if ndu is not None else None))
3739
)
3840
dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du)]
3941
out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
@@ -63,6 +65,7 @@ def perform(self, node, inputs, output_storage):
6365

6466
class SolveLUFactorTridiagonal(Op):
6567
"""Solve a system of linear equations with a tridiagonal coefficient matrix."""
68+
6669
__props__ = ("b_ndim", "overwrite_b")
6770

6871
def __init__(self, b_ndim: int, overwrite_b=False):
@@ -84,21 +87,30 @@ def make_node(self, dl, d, du, du2, ipiv, b):
8487
if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
8588
raise ValueError("Inputs must be vectors")
8689

87-
ndl, nd, ndu, ndu2, nipiv = (inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv))
90+
ndl, nd, ndu, ndu2, nipiv = (
91+
inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv)
92+
)
8893
nb = b.type.shape[0]
8994
n = (
9095
ndl + 1
91-
if ndl is not None else (
92-
nd if nd is not None else (
93-
ndu + 1 if ndu is not None else (
94-
ndu2 + 2 if ndu2 is not None else (
95-
nipiv if nipiv is not None else nb
96-
)
96+
if ndl is not None
97+
else (
98+
nd
99+
if nd is not None
100+
else (
101+
ndu + 1
102+
if ndu is not None
103+
else (
104+
ndu2 + 2
105+
if ndu2 is not None
106+
else (nipiv if nipiv is not None else nb)
97107
)
98108
)
99109
)
100110
)
101-
dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)]
111+
dummy_arrays = [
112+
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
113+
]
102114
# Seems to always be float64?
103115
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
104116
if self.b_ndim == 1:
@@ -111,14 +123,13 @@ def make_node(self, dl, d, du, du2, ipiv, b):
111123

112124
def perform(self, node, inputs, output_storage):
113125
gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
114-
x, _ = gttrs(
115-
*inputs, overwrite_b=self.overwrite_b
116-
)
126+
x, _ = gttrs(*inputs, overwrite_b=self.overwrite_b)
117127
output_storage[0][0] = x
118128

119129

120130
class SolveTridiagonal(Op):
121131
"""Solve a system of linear equations with a tridiagonal dense matrix."""
132+
122133
__props__ = ("b_ndim", "overwrite_b")
123134

124135
def __init__(self, *, b_ndim: int, overwrite_b: bool = False):
@@ -141,7 +152,9 @@ def make_node(self, dl, d, du, b):
141152
raise TypeError("Diagonals must have the same dtype")
142153

143154
if b.type.ndim != self.b_ndim:
144-
raise ValueError(f"Number of dimensions of b does not match promised {self.b_ndim}")
155+
raise ValueError(
156+
f"Number of dimensions of b does not match promised {self.b_ndim}"
157+
)
145158

146159
out_dtype = scipy.linalg.solve(
147160
np.eye((3), dtype=d.type.dtype),
@@ -156,13 +169,14 @@ def L_op(self, node, inputs, outputs, output_grads):
156169

157170
def perform(self, node, inputs, output_storage):
158171
[dl, d, du, b] = inputs
159-
_gttrf, _gttrs = get_lapack_funcs(('gttrf', 'gttrs'), dtype=node.outputs[0].type.dtype)
172+
_gttrf, _gttrs = get_lapack_funcs(
173+
("gttrf", "gttrs"), dtype=node.outputs[0].type.dtype
174+
)
160175

161176
dl, d, du, du2, ipiv, _ = _gttrf(dl, d, du)
162177
x, _ = _gttrs(dl, d, du, du2, ipiv, b, overwrite_b=self.overwrite_b)
163178
output_storage[0][0] = x
164179

165-
166180
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
167181
if 3 not in allowed_inplace_inputs:
168182
return self
@@ -186,6 +200,7 @@ def solve_tridiagonal_from_full_A_b(a, b, b_ndim: int, transposed: bool):
186200
dl, d, du = (diagonal(a, offset=o, axis1=-2, axis2=-1) for o in (-1, 0, 1))
187201
return Blockwise(SolveTridiagonal(b_ndim=b_ndim))(dl, d, du)
188202

203+
189204
def split_solve_tridiagonal(node):
190205
"""Split a generic solve tridiagonal system into the 3 atomic steps:
191206
1. Diagonal extractions
@@ -198,11 +213,21 @@ def split_solve_tridiagonal(node):
198213
core_op = node.op.core_op
199214
assert isinstance(core_op, Solve) and core_op.assume_a == "tridiagonal"
200215
a, b = node.inputs
201-
dl, d, du, du2, ipiv = decompose_of_solve_tridiagonal(a)
202-
return Blockwise(SolveLUFactorTridiagonal(b_ndim=node.op.core_op.b_ndim))(dl, d, du, du2, ipiv, b)
216+
a_decomp = decompose_of_solve_tridiagonal(a)
217+
return solve_decomposed_tridiagonal(a_decomp, b, b_ndim=core_op.b_ndim)
218+
203219

204220
def decompose_of_solve_tridiagonal(a):
    """Return the decomposition of ``a`` implied by a tridiagonal solve.

    The sub-, main and super-diagonal (offsets -1, 0, 1 over the last two
    axes) are extracted from ``a`` and passed to a batched
    ``LUFactorTridiagonal``.

    Returns
    -------
    tuple
        The five outputs of the factorization: ``(dl, d, du, du2, ipiv)``.
    """
    diagonals = [diagonal(a, offset=k, axis1=-2, axis2=-1) for k in (-1, 0, 1)]
    lu_dl, lu_d, lu_du, lu_du2, pivots = Blockwise(LUFactorTridiagonal())(*diagonals)
    return lu_dl, lu_d, lu_du, lu_du2, pivots
225+
226+
227+
def decompose_tridiagonals(dl, d, du):
    """Apply a batched ``LUFactorTridiagonal`` to already-extracted diagonals.

    Parameters
    ----------
    dl, d, du
        The three diagonals, passed to the factorization in this order.

    Returns
    -------
    The outputs of ``Blockwise(LUFactorTridiagonal())(dl, d, du)``, unchanged.
    """
    batched_lu_factor = Blockwise(LUFactorTridiagonal())
    return batched_lu_factor(dl, d, du)
229+
230+
231+
def solve_decomposed_tridiagonal(a_diagonals, b, *, b_ndim: int):
    """Solve against an already-factorized tridiagonal system.

    Parameters
    ----------
    a_diagonals
        Iterable of exactly five entries, unpacked as
        ``(dl, d, du, du2, ipiv)``.
    b
        Right-hand side, forwarded as the last input of the solve.
    b_ndim
        Keyword-only; forwarded to ``SolveLUFactorTridiagonal``.

    Returns
    -------
    The result of the batched ``SolveLUFactorTridiagonal`` call.
    """
    # Unpacking first keeps the "exactly five entries" failure ahead of Op creation
    lu_dl, lu_d, lu_du, lu_du2, pivots = a_diagonals
    batched_solve = Blockwise(SolveLUFactorTridiagonal(b_ndim=b_ndim))
    return batched_solve(lu_dl, lu_d, lu_du, lu_du2, pivots, b)

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
7575
if ndims < 2:
7676
return False
7777
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
78+
79+
# Allow expand_dims on the left of the transpose
80+
if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
81+
transpose_order = (
82+
*(["x"] * diff),
83+
*transpose_order,
84+
)
7885
return node.op.new_order == transpose_order
7986
return False
8087

0 commit comments

Comments
 (0)