Skip to content

Commit 77a9b0b

Browse files
committed
Tests pass locally
1 parent 2c81058 commit 77a9b0b

File tree

2 files changed

+79
-30
lines changed

2 files changed

+79
-30
lines changed

src/TriangularSolve.jl

+44-24
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ end
267267
# Each iter:
268268
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
269269
for nk SafeCloseOpen(n) # nmuladd
270-
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})(nk, n))
270+
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nk, n)))
271271
Base.Cartesian.@nexprs $W c ->
272272
A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c)
273273
end
@@ -320,7 +320,7 @@ end
320320
end
321321
@generated function ldiv_solve_W!(
322322
spc,
323-
s,
323+
spa,
324324
spu,
325325
n,
326326
::StaticInt{W},
@@ -354,16 +354,18 @@ end
354354
# Each iter:
355355
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
356356
for nk SafeCloseOpen(n) # nmuladd
357-
U_ki = vload(spu, (nk, $(MM{W}(z))))
357+
U_ki = vload(spu, (nk, $(MM{W})(n)))
358358
Base.Cartesian.@nexprs $W c ->
359359
A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c)
360360
end
361+
# Base.Cartesian.@nexprs $W c -> @show A11_c
361362
# solve AU wants us to transpose
362363
# We then have column-major multiplies
363364
# take A[(u-1)*W,u*W), [0,W)]
364365
X = VectorizationBase.transpose_vecunroll(
365366
VecUnroll(Base.Cartesian.@ntuple $W A11)
366367
)
368+
# @show X
367369
C_u = solve_AU(X, spu, n, $(Val(UNIT)))
368370
vstore!(spc, C_u, $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n)))
369371
end
@@ -402,13 +404,13 @@ end
402404
A11 =
403405
getfield(vload(spa, $(Unroll{1,1,R,2,W,zero(UInt),1})(($z, n))), :data)
404406
# The `W` rows
405-
Base.Cartesian.@nexprs $W r -> A11_r = getfield(A11, r)
407+
Base.Cartesian.@nexprs $R r -> A11_r = getfield(A11, r)
406408
# compute
407409
# A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k})
408410
# Each iter:
409411
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
410412
for nk SafeCloseOpen(n) # nmuladd
411-
U_ki = vload(spu, (nk, $(MM{W}(z))))
413+
U_ki = vload(spu, (nk, $(MM{W})(n)))
412414
Base.Cartesian.@nexprs $R r ->
413415
A11_r = vfnmadd_fast(U_ki, vload(spc, (static(r - 1), nk)), A11_r)
414416
end
@@ -432,13 +434,13 @@ end
432434
push!(q.args, q2)
433435
q3 = if R == Wpad
434436
quote
435-
i = $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n))
437+
i = $(Unroll{2,1,W,1,Wpad,zero(UInt),1})(($z, n))
436438
vstore!(spc, C_u, i)
437439
end
438440
else
439441
quote
440442
mask = VectorizationBase.mask($WS, $(static(R)))
441-
i = $(Unroll{2,1,W,1,W,(-1 % UInt),1})(($z, n))
443+
i = $(Unroll{2,1,W,1,Wpad,(-1 % UInt),1})(($z, n))
442444
vstore!(spc, C_u, i, mask)
443445
end
444446
end
@@ -890,7 +892,8 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
890892
Core.ifelse(block == Nblock - 1, Mrem, mtb),
891893
N,
892894
Val{UNIT}(),
893-
static(XC)static(XA)
895+
static(XC),
896+
static(XA)
894897
)
895898
end
896899
end
@@ -1000,29 +1003,35 @@ end
10001003
N,
10011004
m,
10021005
Nr,
1003-
::Val{W},
1006+
::StaticInt{W},
10041007
::Val{UNIT},
1005-
::Val{r}
1008+
::StaticInt{r}
10061009
) where {W,UNIT,r}
10071010
r <= 0 && throw("Remainder of `<= 0` shouldn't be called, but had $r.")
10081011
r >= W && throw("Reaminderof `>= $W` shouldn't be called, but had $r.")
10091012
if r == 1
1010-
vlxj = :(vload(spc, (M - 1, j)))
1011-
if !UNIT
1012-
vlxj = :($vlxj / vload(spu, (j, j)))
1013+
z = static(0)
1014+
vlxj = :(vload(spc, ($z, j)))
1015+
if UNIT
1016+
vlxj = :(xj = $vlxj)
1017+
else
1018+
vlxj = quote
1019+
xj = $vlxj / vload(spu, (j, j))
1020+
vstore!(spc, xj, ($z, j))
1021+
end
10131022
end
10141023
quote
10151024
if pointer(spc) != pointer(spa)
10161025
for n = 0:N-1
1017-
vstore!(spc, vload(spa, (M - 1, n)), (M - 1, n))
1026+
vstore!(spc, vload(spa, ($z, n)), ($z, n))
10181027
end
10191028
end
10201029
for j = 0:N-1
1021-
xj = $vlxj
1030+
$vlxj
10221031
for i = (j+1):N-1
1023-
xi = vload(spc, (M - 1, i))
1032+
xi = vload(spc, ($z, i))
10241033
Uji = vload(spu, (j, i))
1025-
vstore!(spc, xi - xj * Uji, (M - 1, i))
1034+
vstore!(spc, xi - xj * Uji, ($z, i))
10261035
end
10271036
end
10281037
end
@@ -1033,14 +1042,14 @@ end
10331042
n = Nr # non factor of W remainder
10341043
if n > 0
10351044
mask = $(VectorizationBase.mask(WS, r))
1036-
BdivU_small_kern!(spc, nothing, spa, spu, n, mask, Val(UNIT))
1045+
BdivU_small_kern!(spc, nothing, spa, spu, n, mask, $(Val(UNIT)))
10371046
end
10381047
# while n < N - $(W * U - 1)
1039-
# ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(w))
1048+
# ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(r))
10401049
# n += $(W * U)
10411050
# end
10421051
while n != N
1043-
ldiv_solve_W!(spc, spa, spu, n, $WS, Val(UNIT), Val(w))
1052+
ldiv_solve_W!(spc, spa, spu, n, $WS, $(Val(UNIT)), $(StaticInt(r)))
10441053
n += $W
10451054
end
10461055
end
@@ -1054,16 +1063,26 @@ end
10541063
N,
10551064
m,
10561065
Nr,
1057-
::Val{W},
1066+
::StaticInt{W},
10581067
# ::Val{U},
10591068
::Val{UNIT}
10601069
) where {W,UNIT}
10611070
WS = static(W)
10621071
# US = static(U)
10631072
quote
10641073
$(Expr(:meta, :inline))
1065-
Base.Cartesian.@nif $W w -> m == M - w w ->
1066-
ldiv_remainder!(spc, spa, spu, M, N, m, Nr, $WS, $(Val(UNIT)), Val(w))
1074+
Base.Cartesian.@nif $W w -> m == M - w w -> ldiv_remainder!(
1075+
spc,
1076+
spa,
1077+
spu,
1078+
M,
1079+
N,
1080+
m,
1081+
Nr,
1082+
$WS,
1083+
$(Val(UNIT)),
1084+
StaticInt(w)
1085+
)
10671086
end
10681087
end
10691088

@@ -1087,11 +1106,12 @@ function rdiv_U!(
10871106
MU = UF > 1 ? M : 0
10881107
Nd, Nr = VectorizationBase.vdivrem(N, WS)
10891108
m = 0
1109+
# @show M,N
10901110
# m, no remainder
10911111
while m < M - WS + 1
10921112
n = Nr # non factor of W remainder
10931113
if n > 0
1094-
BdivU_small_kern_u!(spc, nothing, spa, spu, n, Val(1), Val(UNIT))
1114+
BdivU_small_kern_u!(spc, nothing, spa, spu, n, StaticInt(1), Val(UNIT))
10951115
end
10961116
while n < N - (WU - 1)
10971117
ldiv_solve_W_u!(spc, spa, spu, n, WS, UF, Val(UNIT))

test/runtests.jl

+35-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,31 @@
11
using TriangularSolve, LinearAlgebra
22
using Test
33

4+
function check_box_for_nans(A, M, N)
5+
# blocks start at 17, and are MxN
6+
@test all(isnan, @view(A[1:16, :]))
7+
@test all(isnan, @view(A[17+M:end, :]))
8+
@test all(isnan, @view(A[17:16+M, 1:16]))
9+
@test all(isnan, @view(A[17:16+M, 17+N:end]))
10+
end
11+
412
function test_solve(::Type{T}) where {T}
5-
for n 1:(T === Float32 ? 100 : 200)
13+
maxN = (T === Float32 ? 100 : 200)
14+
maxM = maxN + 10
15+
AA = fill(T(NaN), maxM + 32, maxM + 32)
16+
RR = fill(T(NaN), maxM + 32, maxM + 32)
17+
BB = fill(T(NaN), maxN + 32, maxN + 32)
18+
for n 1:maxN
619
@show n
720
for m max(1, n - 10):n+10
8-
A = rand(T, m, n)
9-
res = similar(A)
10-
B = rand(T, n, n) + I
21+
A = @view AA[17:16+m, 17:16+n]
22+
res = @view RR[17:16+m, 17:16+n]
23+
B = @view BB[17:16+n, 17:16+n]
24+
25+
A .= rand.(T)
26+
B .= rand.(T)
27+
@view(B[diagind(B)]) .+= one(T)
28+
1129
@test TriangularSolve.rdiv!(res, A, UpperTriangular(B)) *
1230
UpperTriangular(B) A
1331
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B)) *
@@ -16,8 +34,15 @@ function test_solve(::Type{T}) where {T}
1634
UpperTriangular(B) A
1735
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B), Val(false)) *
1836
UnitUpperTriangular(B) A
19-
A = rand(T, n, m)
20-
res = similar(A)
37+
38+
check_box_for_nans(RR, m, n)
39+
res .= NaN
40+
A .= NaN
41+
42+
A = @view AA[17:16+n, 17:16+m]
43+
res = @view RR[17:16+n, 17:16+m]
44+
A .= rand.(T)
45+
2146
@test LowerTriangular(B) *
2247
TriangularSolve.ldiv!(res, LowerTriangular(B), A) A
2348
@test UnitLowerTriangular(B) *
@@ -27,6 +52,10 @@ function test_solve(::Type{T}) where {T}
2752
@test UnitLowerTriangular(B) *
2853
TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A, Val(false))
2954
A
55+
check_box_for_nans(RR, n, m)
56+
res .= NaN
57+
A .= NaN
58+
B .= NaN
3059
end
3160
end
3261
end

0 commit comments

Comments
 (0)