1
- import scipy
2
1
import numpy as np
2
+ import scipy
3
3
from scipy .linalg import get_lapack_funcs
4
4
5
- from pytensor .graph import Op , Apply
5
+ from pytensor .graph import Apply , Op
6
6
from pytensor .tensor .basic import as_tensor , diagonal
7
- from pytensor .tensor .type import tensor , vector
8
7
from pytensor .tensor .blockwise import Blockwise
9
8
from pytensor .tensor .slinalg import Solve
9
+ from pytensor .tensor .type import tensor , vector
10
10
11
11
12
12
class LUFactorTridiagonal (Op ):
13
13
"""Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
14
- __props__ = ("overwrite_dl" , "overwrite_d" , "overwrite_du" ,)
14
+
15
+ __props__ = (
16
+ "overwrite_dl" ,
17
+ "overwrite_d" ,
18
+ "overwrite_du" ,
19
+ )
15
20
gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
16
21
17
22
def __init__ (self , overwrite_dl = False , overwrite_d = False , overwrite_du = False ):
@@ -29,11 +34,8 @@ def make_node(self, dl, d, du):
29
34
ndl , nd , ndu = (inp .type .shape [- 1 ] for inp in (dl , d , du ))
30
35
n = (
31
36
ndl + 1
32
- if ndl is not None else (
33
- nd if nd is not None else (
34
- ndu + 1 if ndu is not None else None
35
- )
36
- )
37
+ if ndl is not None
38
+ else (nd if nd is not None else (ndu + 1 if ndu is not None else None ))
37
39
)
38
40
dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du )]
39
41
out_dtype = get_lapack_funcs ("gttrf" , dummy_arrays ).dtype
@@ -63,6 +65,7 @@ def perform(self, node, inputs, output_storage):
63
65
64
66
class SolveLUFactorTridiagonal (Op ):
65
67
"""Solve a system of linear equations with a tridiagonal coefficient matrix."""
68
+
66
69
__props__ = ("b_ndim" , "overwrite_b" )
67
70
68
71
def __init__ (self , b_ndim : int , overwrite_b = False ):
@@ -84,21 +87,30 @@ def make_node(self, dl, d, du, du2, ipiv, b):
84
87
if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv )):
85
88
raise ValueError ("Inputs must be vectors" )
86
89
87
- ndl , nd , ndu , ndu2 , nipiv = (inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv ))
90
+ ndl , nd , ndu , ndu2 , nipiv = (
91
+ inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv )
92
+ )
88
93
nb = b .type .shape [0 ]
89
94
n = (
90
95
ndl + 1
91
- if ndl is not None else (
92
- nd if nd is not None else (
93
- ndu + 1 if ndu is not None else (
94
- ndu2 + 2 if ndu2 is not None else (
95
- nipiv if nipiv is not None else nb
96
- )
96
+ if ndl is not None
97
+ else (
98
+ nd
99
+ if nd is not None
100
+ else (
101
+ ndu + 1
102
+ if ndu is not None
103
+ else (
104
+ ndu2 + 2
105
+ if ndu2 is not None
106
+ else (nipiv if nipiv is not None else nb )
97
107
)
98
108
)
99
109
)
100
110
)
101
- dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du , du2 , ipiv )]
111
+ dummy_arrays = [
112
+ np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du , du2 , ipiv )
113
+ ]
102
114
# Seems to always be float64?
103
115
out_dtype = get_lapack_funcs ("gttrs" , dummy_arrays ).dtype
104
116
if self .b_ndim == 1 :
@@ -111,14 +123,13 @@ def make_node(self, dl, d, du, du2, ipiv, b):
111
123
112
124
def perform (self , node , inputs , output_storage ):
113
125
gttrs = get_lapack_funcs ("gttrs" , dtype = node .outputs [0 ].type .dtype )
114
- x , _ = gttrs (
115
- * inputs , overwrite_b = self .overwrite_b
116
- )
126
+ x , _ = gttrs (* inputs , overwrite_b = self .overwrite_b )
117
127
output_storage [0 ][0 ] = x
118
128
119
129
120
130
class SolveTridiagonal (Op ):
121
131
"""Solve a system of linear equations with a tridiagonal dense matrix."""
132
+
122
133
__props__ = ("b_ndim" , "overwrite_b" )
123
134
124
135
def __init__ (self , * , b_ndim : int , overwrite_b : bool = False ):
@@ -141,7 +152,9 @@ def make_node(self, dl, d, du, b):
141
152
raise TypeError ("Diagonals must have the same dtype" )
142
153
143
154
if b .type .ndim != self .b_ndim :
144
- raise ValueError (f"Number of dimensions of b does not match promised { self .b_ndim } " )
155
+ raise ValueError (
156
+ f"Number of dimensions of b does not match promised { self .b_ndim } "
157
+ )
145
158
146
159
out_dtype = scipy .linalg .solve (
147
160
np .eye ((3 ), dtype = d .type .dtype ),
@@ -156,13 +169,14 @@ def L_op(self, node, inputs, outputs, output_grads):
156
169
157
170
def perform (self , node , inputs , output_storage ):
158
171
[dl , d , du , b ] = inputs
159
- _gttrf , _gttrs = get_lapack_funcs (('gttrf' , 'gttrs' ), dtype = node .outputs [0 ].type .dtype )
172
+ _gttrf , _gttrs = get_lapack_funcs (
173
+ ("gttrf" , "gttrs" ), dtype = node .outputs [0 ].type .dtype
174
+ )
160
175
161
176
dl , d , du , du2 , ipiv , _ = _gttrf (dl , d , du )
162
177
x , _ = _gttrs (dl , d , du , du2 , ipiv , b , overwrite_b = self .overwrite_b )
163
178
output_storage [0 ][0 ] = x
164
179
165
-
166
180
def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
167
181
if 3 not in allowed_inplace_inputs :
168
182
return self
@@ -186,6 +200,7 @@ def solve_tridiagonal_from_full_A_b(a, b, b_ndim: int, transposed: bool):
186
200
dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
187
201
return Blockwise (SolveTridiagonal (b_ndim = b_ndim ))(dl , d , du )
188
202
203
+
189
204
def split_solve_tridiagonal (node ):
190
205
"""Split a generic solve tridiagonal system into the 3 atomic steps:
191
206
1. Diagonal extractions
@@ -198,11 +213,21 @@ def split_solve_tridiagonal(node):
198
213
core_op = node .op .core_op
199
214
assert isinstance (core_op , Solve ) and core_op .assume_a == "tridiagonal"
200
215
a , b = node .inputs
201
- dl , d , du , du2 , ipiv = decompose_of_solve_tridiagonal (a )
202
- return Blockwise (SolveLUFactorTridiagonal (b_ndim = node .op .core_op .b_ndim ))(dl , d , du , du2 , ipiv , b )
216
+ a_decomp = decompose_of_solve_tridiagonal (a )
217
+ return solve_decomposed_tridiagonal (a_decomp , b , b_ndim = core_op .b_ndim )
218
+
203
219
204
220
def decompose_of_solve_tridiagonal (a ):
205
221
# Return the decomposition of A implied by a solve tridiagonal
206
222
dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
207
223
dl , d , du , du2 , ipiv = Blockwise (LUFactorTridiagonal ())(dl , d , du )
208
224
return dl , d , du , du2 , ipiv
225
+
226
+
227
+ def decompose_tridiagonals (dl , d , du ):
228
+ return Blockwise (LUFactorTridiagonal ())(dl , d , du )
229
+
230
+
231
+ def solve_decomposed_tridiagonal (a_diagonals , b , * , b_ndim : int ):
232
+ dl , d , du , du2 , ipiv = a_diagonals
233
+ return Blockwise (SolveLUFactorTridiagonal (b_ndim = b_ndim ))(dl , d , du , du2 , ipiv , b )
0 commit comments