267
267
# Each iter:
268
268
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
269
269
for nk ∈ SafeCloseOpen (n) # nmuladd
270
- U_ki = vload (spu, $ (Unroll{2 ,W,U,2 ,W,zero (UInt),1 })(nk, n))
270
+ U_ki = vload (spu, $ (Unroll{2 ,W,U,2 ,W,zero (UInt),1 })(( nk, n) ))
271
271
Base. Cartesian. @nexprs $ W c ->
272
272
A11_c = vfnmadd_fast (U_ki, vload (spc, (static (c - 1 ), nk)), A11_c)
273
273
end
320
320
end
321
321
@generated function ldiv_solve_W! (
322
322
spc,
323
- s ,
323
+ spa ,
324
324
spu,
325
325
n,
326
326
:: StaticInt{W} ,
@@ -354,16 +354,18 @@ end
354
354
# Each iter:
355
355
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
356
356
for nk ∈ SafeCloseOpen (n) # nmuladd
357
- U_ki = vload (spu, (nk, $ (MM {W} (z) )))
357
+ U_ki = vload (spu, (nk, $ (MM{W})(n )))
358
358
Base. Cartesian. @nexprs $ W c ->
359
359
A11_c = vfnmadd_fast (U_ki, vload (spc, (static (c - 1 ), nk)), A11_c)
360
360
end
361
+ # Base.Cartesian.@nexprs $W c -> @show A11_c
361
362
# solve AU wants us to transpose
362
363
# We then have column-major multiplies
363
364
# take A[[(u-1)*W, u*W), [0,W)]
364
365
X = VectorizationBase. transpose_vecunroll (
365
366
VecUnroll (Base. Cartesian. @ntuple $ W A11)
366
367
)
368
+ # @show X
367
369
C_u = solve_AU (X, spu, n, $ (Val (UNIT)))
368
370
vstore! (spc, C_u, $ (Unroll{2 ,1 ,W,1 ,W,zero (UInt),1 })(($ z, n)))
369
371
end
@@ -402,13 +404,13 @@ end
402
404
A11 =
403
405
getfield (vload (spa, $ (Unroll{1 ,1 ,R,2 ,W,zero (UInt),1 })(($ z, n))), :data )
404
406
# The `W` rows
405
- Base. Cartesian. @nexprs $ W r -> A11_r = getfield (A11, r)
407
+ Base. Cartesian. @nexprs $ R r -> A11_r = getfield (A11, r)
406
408
# compute
407
409
# A_{j,i} - \sum_{k=1}^{i-1} U_{k,i} C_{j,k}
408
410
# Each iter:
409
411
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
410
412
for nk ∈ SafeCloseOpen (n) # nmuladd
411
- U_ki = vload (spu, (nk, $ (MM {W} (z) )))
413
+ U_ki = vload (spu, (nk, $ (MM{W})(n )))
412
414
Base. Cartesian. @nexprs $ R r ->
413
415
A11_r = vfnmadd_fast (U_ki, vload (spc, (static (r - 1 ), nk)), A11_r)
414
416
end
@@ -432,13 +434,13 @@ end
432
434
push! (q. args, q2)
433
435
q3 = if R == Wpad
434
436
quote
435
- i = $ (Unroll{2 ,1 ,W,1 ,W ,zero (UInt),1 })(($ z, n))
437
+ i = $ (Unroll{2 ,1 ,W,1 ,Wpad ,zero (UInt),1 })(($ z, n))
436
438
vstore! (spc, C_u, i)
437
439
end
438
440
else
439
441
quote
440
442
mask = VectorizationBase. mask ($ WS, $ (static (R)))
441
- i = $ (Unroll{2 ,1 ,W,1 ,W ,(- 1 % UInt),1 })(($ z, n))
443
+ i = $ (Unroll{2 ,1 ,W,1 ,Wpad ,(- 1 % UInt),1 })(($ z, n))
442
444
vstore! (spc, C_u, i, mask)
443
445
end
444
446
end
@@ -890,7 +892,8 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
890
892
Core. ifelse (block == Nblock - 1 , Mrem, mtb),
891
893
N,
892
894
Val {UNIT} (),
893
- static (XC)static (XA)
895
+ static (XC),
896
+ static (XA)
894
897
)
895
898
end
896
899
end
@@ -1000,29 +1003,35 @@ end
1000
1003
N,
1001
1004
m,
1002
1005
Nr,
1003
- :: Val {W} ,
1006
+ :: StaticInt {W} ,
1004
1007
:: Val{UNIT} ,
1005
- :: Val {r}
1008
+ :: StaticInt {r}
1006
1009
) where {W,UNIT,r}
1007
1010
r <= 0 && throw (" Remainder of `<= 0` shouldn't be called, but had $r ." )
1008
1011
r >= W && throw (" Reaminderof `>= $W ` shouldn't be called, but had $r ." )
1009
1012
if r == 1
1010
- vlxj = :(vload (spc, (M - 1 , j)))
1011
- if ! UNIT
1012
- vlxj = :($ vlxj / vload (spu, (j, j)))
1013
+ z = static (0 )
1014
+ vlxj = :(vload (spc, ($ z, j)))
1015
+ if UNIT
1016
+ vlxj = :(xj = $ vlxj)
1017
+ else
1018
+ vlxj = quote
1019
+ xj = $ vlxj / vload (spu, (j, j))
1020
+ vstore! (spc, xj, ($ z, j))
1021
+ end
1013
1022
end
1014
1023
quote
1015
1024
if pointer (spc) != pointer (spa)
1016
1025
for n = 0 : N- 1
1017
- vstore! (spc, vload (spa, (M - 1 , n)), (M - 1 , n))
1026
+ vstore! (spc, vload (spa, ($ z , n)), ($ z , n))
1018
1027
end
1019
1028
end
1020
1029
for j = 0 : N- 1
1021
- xj = $ vlxj
1030
+ $ vlxj
1022
1031
for i = (j+ 1 ): N- 1
1023
- xi = vload (spc, (M - 1 , i))
1032
+ xi = vload (spc, ($ z , i))
1024
1033
Uji = vload (spu, (j, i))
1025
- vstore! (spc, xi - xj * Uji, (M - 1 , i))
1034
+ vstore! (spc, xi - xj * Uji, ($ z , i))
1026
1035
end
1027
1036
end
1028
1037
end
@@ -1033,14 +1042,14 @@ end
1033
1042
n = Nr # non factor of W remainder
1034
1043
if n > 0
1035
1044
mask = $ (VectorizationBase. mask (WS, r))
1036
- BdivU_small_kern! (spc, nothing , spa, spu, n, mask, Val (UNIT))
1045
+ BdivU_small_kern! (spc, nothing , spa, spu, n, mask, $ ( Val (UNIT) ))
1037
1046
end
1038
1047
# while n < N - $(W * U - 1)
1039
- # ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(w ))
1048
+ # ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(r ))
1040
1049
# n += $(W * U)
1041
1050
# end
1042
1051
while n != N
1043
- ldiv_solve_W! (spc, spa, spu, n, $ WS, Val (UNIT), Val (w ))
1052
+ ldiv_solve_W! (spc, spa, spu, n, $ WS, $ ( Val (UNIT)), $ ( StaticInt (r) ))
1044
1053
n += $ W
1045
1054
end
1046
1055
end
@@ -1054,16 +1063,26 @@ end
1054
1063
N,
1055
1064
m,
1056
1065
Nr,
1057
- :: Val {W} ,
1066
+ :: StaticInt {W} ,
1058
1067
# ::Val{U},
1059
1068
:: Val{UNIT}
1060
1069
) where {W,UNIT}
1061
1070
WS = static (W)
1062
1071
# US = static(U)
1063
1072
quote
1064
1073
$ (Expr (:meta , :inline ))
1065
- Base. Cartesian. @nif $ W w -> m == M - w w ->
1066
- ldiv_remainder! (spc, spa, spu, M, N, m, Nr, $ WS, $ (Val (UNIT)), Val (w))
1074
+ Base. Cartesian. @nif $ W w -> m == M - w w -> ldiv_remainder! (
1075
+ spc,
1076
+ spa,
1077
+ spu,
1078
+ M,
1079
+ N,
1080
+ m,
1081
+ Nr,
1082
+ $ WS,
1083
+ $ (Val (UNIT)),
1084
+ StaticInt (w)
1085
+ )
1067
1086
end
1068
1087
end
1069
1088
@@ -1087,11 +1106,12 @@ function rdiv_U!(
1087
1106
MU = UF > 1 ? M : 0
1088
1107
Nd, Nr = VectorizationBase. vdivrem (N, WS)
1089
1108
m = 0
1109
+ # @show M,N
1090
1110
# m, no remainder
1091
1111
while m < M - WS + 1
1092
1112
n = Nr # non factor of W remainder
1093
1113
if n > 0
1094
- BdivU_small_kern_u! (spc, nothing , spa, spu, n, Val (1 ), Val (UNIT))
1114
+ BdivU_small_kern_u! (spc, nothing , spa, spu, n, StaticInt (1 ), Val (UNIT))
1095
1115
end
1096
1116
while n < N - (WU - 1 )
1097
1117
ldiv_solve_W_u! (spc, spa, spu, n, WS, UF, Val (UNIT))
0 commit comments