@@ -100,43 +100,40 @@ def find_solve_clients(var, assume_a):
100
100
elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
101
101
# If it's a left expand_dims, recurse on the output
102
102
clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
103
-
104
103
return clients
105
104
106
105
assume_a = node .op .core_op .assume_a
107
106
108
107
if assume_a not in allowed_assume_a :
109
108
return None
110
109
111
- root_A , root_A_transposed = get_root_A (node .inputs [0 ])
110
+ A , _ = get_root_A (node .inputs [0 ])
112
111
113
112
# Find Solve using A (or left expand_dims of A)
114
113
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
115
114
# that to the A_decomp outputs
116
- root_A_solve_clients_and_transpose = [
117
- (client , False ) for client in find_solve_clients (root_A , assume_a )
115
+ A_solve_clients_and_transpose = [
116
+ (client , False ) for client in find_solve_clients (A , assume_a )
118
117
]
119
118
120
119
# Find Solves using A.T
121
- for cl , _ in fgraph .clients [root_A ]:
120
+ for cl , _ in fgraph .clients [A ]:
122
121
if isinstance (cl .op , DimShuffle ) and is_matrix_transpose (cl .out ):
123
122
A_T = cl .out
124
- root_A_solve_clients_and_transpose .extend (
123
+ A_solve_clients_and_transpose .extend (
125
124
(client , True ) for client in find_solve_clients (A_T , assume_a )
126
125
)
127
126
128
- if not eager and len (root_A_solve_clients_and_transpose ) == 1 :
127
+ if not eager and len (A_solve_clients_and_transpose ) == 1 :
129
128
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
130
129
# That's a "reuse" inside the inner vectorized loop
131
130
batch_ndim = node .op .batch_ndim (node )
132
- (client , _ ) = root_A_solve_clients_and_transpose [0 ]
133
-
134
- A , b = client .inputs
135
-
131
+ (client , _ ) = A_solve_clients_and_transpose [0 ]
132
+ original_A , b = client .inputs
136
133
if not any (
137
134
a_bcast and not b_bcast
138
135
for a_bcast , b_bcast in zip (
139
- A .type .broadcastable [:batch_ndim ],
136
+ original_A .type .broadcastable [:batch_ndim ],
140
137
b .type .broadcastable [:batch_ndim ],
141
138
strict = True ,
142
139
)
@@ -145,27 +142,19 @@ def find_solve_clients(var, assume_a):
145
142
146
143
# If any Op had check_finite=True, we also do it for the LU decomposition
147
144
check_finite_decomp = False
148
- for client , _ in root_A_solve_clients_and_transpose :
145
+ for client , _ in A_solve_clients_and_transpose :
149
146
if client .op .core_op .check_finite :
150
147
check_finite_decomp = True
151
148
break
152
149
153
- (first_solve , transposed ) = root_A_solve_clients_and_transpose [0 ]
154
- lower = first_solve .op .core_op .lower
155
- if transposed :
156
- lower = not lower
157
-
150
+ lower = node .op .core_op .lower
158
151
A_decomp = decompose_A (
159
- root_A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
152
+ A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
160
153
)
161
154
162
155
replacements = {}
163
- for client , transposed in root_A_solve_clients_and_transpose :
156
+ for client , transposed in A_solve_clients_and_transpose :
164
157
_ , b = client .inputs
165
- lower = client .op .core_op .lower
166
- if transposed :
167
- lower = not lower
168
-
169
158
new_x = solve_decomposed_system (
170
159
A_decomp ,
171
160
b ,
0 commit comments