Skip to content

Commit 1e4972e

Browse files
authored
Fix QR and LQ pullback for rank 0 blocks (#223)
* Replace `findlast` with `count` * replace count with searchsortedlast * Revert back to `findlast` for QR and LQ
1 parent a4eb3f3 commit 1e4972e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
219219
Sp = view(S, 1:p)
220220

221221
# rank
222-
r = count(>(tol), S)
222+
r = searchsortedlast(S, tol; rev=true)
223223

224224
# compute antihermitian part of projection of ΔU and ΔV onto U and V
225225
# also already subtract this projection from ΔU and ΔV
@@ -376,7 +376,7 @@ end
376376
function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
377377
tol::Real=default_pullback_gaugetol(R))
378378
Rd = view(R, diagind(R))
379-
p = findlast(>=(tol) abs, Rd)
379+
p = something(findlast((tol) abs, Rd), 0)
380380
m, n = size(R)
381381

382382
Q1 = view(Q, :, 1:p)
@@ -427,7 +427,7 @@ end
427427
function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
428428
tol::Real=default_pullback_gaugetol(L))
429429
Ld = view(L, diagind(L))
430-
p = findlast(>=(tol) abs, Ld)
430+
p = something(findlast((tol) abs, Ld), 0)
431431
m, n = size(L)
432432

433433
L1 = view(L, :, 1:p)

0 commit comments

Comments
 (0)