Skip to content

Commit 025c025

Browse files
Revert "More carefully handle lower flag in Solve"
This reverts commit 388e93e.
1 parent 8a87fc4 commit 025c025

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,43 +100,40 @@ def find_solve_clients(var, assume_a):
100100
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
101101
# If it's a left expand_dims, recurse on the output
102102
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
103-
104103
return clients
105104

106105
assume_a = node.op.core_op.assume_a
107106

108107
if assume_a not in allowed_assume_a:
109108
return None
110109

111-
root_A, root_A_transposed = get_root_A(node.inputs[0])
110+
A, _ = get_root_A(node.inputs[0])
112111

113112
# Find Solve using A (or left expand_dims of A)
114113
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
115114
# that to the A_decomp outputs
116-
root_A_solve_clients_and_transpose = [
117-
(client, False) for client in find_solve_clients(root_A, assume_a)
115+
A_solve_clients_and_transpose = [
116+
(client, False) for client in find_solve_clients(A, assume_a)
118117
]
119118

120119
# Find Solves using A.T
121-
for cl, _ in fgraph.clients[root_A]:
120+
for cl, _ in fgraph.clients[A]:
122121
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
123122
A_T = cl.out
124-
root_A_solve_clients_and_transpose.extend(
123+
A_solve_clients_and_transpose.extend(
125124
(client, True) for client in find_solve_clients(A_T, assume_a)
126125
)
127126

128-
if not eager and len(root_A_solve_clients_and_transpose) == 1:
127+
if not eager and len(A_solve_clients_and_transpose) == 1:
129128
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
130129
# That's a "reuse" inside the inner vectorized loop
131130
batch_ndim = node.op.batch_ndim(node)
132-
(client, _) = root_A_solve_clients_and_transpose[0]
133-
134-
A, b = client.inputs
135-
131+
(client, _) = A_solve_clients_and_transpose[0]
132+
original_A, b = client.inputs
136133
if not any(
137134
a_bcast and not b_bcast
138135
for a_bcast, b_bcast in zip(
139-
A.type.broadcastable[:batch_ndim],
136+
original_A.type.broadcastable[:batch_ndim],
140137
b.type.broadcastable[:batch_ndim],
141138
strict=True,
142139
)
@@ -145,27 +142,19 @@ def find_solve_clients(var, assume_a):
145142

146143
# If any Op had check_finite=True, we also do it for the LU decomposition
147144
check_finite_decomp = False
148-
for client, _ in root_A_solve_clients_and_transpose:
145+
for client, _ in A_solve_clients_and_transpose:
149146
if client.op.core_op.check_finite:
150147
check_finite_decomp = True
151148
break
152149

153-
(first_solve, transposed) = root_A_solve_clients_and_transpose[0]
154-
lower = first_solve.op.core_op.lower
155-
if transposed:
156-
lower = not lower
157-
150+
lower = node.op.core_op.lower
158151
A_decomp = decompose_A(
159-
root_A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
152+
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
160153
)
161154

162155
replacements = {}
163-
for client, transposed in root_A_solve_clients_and_transpose:
156+
for client, transposed in A_solve_clients_and_transpose:
164157
_, b = client.inputs
165-
lower = client.op.core_op.lower
166-
if transposed:
167-
lower = not lower
168-
169158
new_x = solve_decomposed_system(
170159
A_decomp,
171160
b,

0 commit comments

Comments
 (0)