1
+ module LinearSolveBLISExt
2
+
3
+ using Libdl
4
+ using blis_jll
5
+ using LinearAlgebra
6
+ using LinearSolve
7
+
8
+ using LinearAlgebra: BlasInt, LU
9
+ using LinearAlgebra. LAPACK: require_one_based_indexing, chkfinite, chkstride1,
10
+ @blasfunc , chkargsok
11
+ using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval , LinearCache, SciMLBase
12
+
13
# Handle to the BLIS shared library, opened from the path exported by `blis_jll`.
# NOTE(review): this `dlopen` runs when the module body is evaluated; if the module is
# precompiled, the stored handle may not be valid in a later session — consider opening
# the library from `__init__` instead. TODO confirm.
const libblis = dlopen(blis_jll.blis_path)
14
+
15
# LU-factorize the ComplexF64 matrix `A` in place by calling `zgetrf_` from `libblis`.
#
# Keywords:
#   ipiv  - pivot buffer. A fresh buffer is allocated whenever the supplied one does not
#           hold exactly `min(size(A, 1), size(A, 2))` entries (this covers the empty
#           placeholder case as well).
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#   check - when `true`, `chkfinite(A)` is run before factorizing.
#
# Returns the tuple `(A, ipiv, info[], info)`.
function getrf!(A::AbstractMatrix{<:ComplexF64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A buffer of any other length would either be
    # overrun by the native call (too short) or carry stale trailing entries into the
    # returned pivots (too long), so it is replaced instead of reused.
    npiv = min(m, n)
    if length(ipiv) != npiv
        ipiv = similar(A, BlasInt, npiv)
    end
    ccall((@blasfunc(zgetrf_), libblis), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
34
+
35
# LU-factorize the ComplexF32 matrix `A` in place by calling `cgetrf_` from `libblis`.
#
# Keywords:
#   ipiv  - pivot buffer. A fresh buffer is allocated whenever the supplied one does not
#           hold exactly `min(size(A, 1), size(A, 2))` entries (this covers the empty
#           placeholder case as well).
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#   check - when `true`, `chkfinite(A)` is run before factorizing.
#
# Returns the tuple `(A, ipiv, info[], info)`.
function getrf!(A::AbstractMatrix{<:ComplexF32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A buffer of any other length would either be
    # overrun by the native call (too short) or carry stale trailing entries into the
    # returned pivots (too long), so it is replaced instead of reused.
    npiv = min(m, n)
    if length(ipiv) != npiv
        ipiv = similar(A, BlasInt, npiv)
    end
    ccall((@blasfunc(cgetrf_), libblis), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
54
+
55
# LU-factorize the Float64 matrix `A` in place by calling `dgetrf_` from `libblis`.
#
# Keywords:
#   ipiv  - pivot buffer. A fresh buffer is allocated whenever the supplied one does not
#           hold exactly `min(size(A, 1), size(A, 2))` entries (this covers the empty
#           placeholder case as well).
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#   check - when `true`, `chkfinite(A)` is run before factorizing.
#
# Returns the tuple `(A, ipiv, info[], info)`.
function getrf!(A::AbstractMatrix{<:Float64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A buffer of any other length would either be
    # overrun by the native call (too short) or carry stale trailing entries into the
    # returned pivots (too long), so it is replaced instead of reused.
    npiv = min(m, n)
    if length(ipiv) != npiv
        ipiv = similar(A, BlasInt, npiv)
    end
    ccall((@blasfunc(dgetrf_), libblis), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
74
+
75
# LU-factorize the Float32 matrix `A` in place by calling `sgetrf_` from `libblis`.
#
# Keywords:
#   ipiv  - pivot buffer. A fresh buffer is allocated whenever the supplied one does not
#           hold exactly `min(size(A, 1), size(A, 2))` entries (this covers the empty
#           placeholder case as well).
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#   check - when `true`, `chkfinite(A)` is run before factorizing.
#
# Returns the tuple `(A, ipiv, info[], info)`.
function getrf!(A::AbstractMatrix{<:Float32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A buffer of any other length would either be
    # overrun by the native call (too short) or carry stale trailing entries into the
    # returned pivots (too long), so it is replaced instead of reused.
    npiv = min(m, n)
    if length(ipiv) != npiv
        ipiv = similar(A, BlasInt, npiv)
    end
    ccall((@blasfunc(sgetrf_), libblis), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
94
+
95
# Solve a linear system in place from LU factors by calling `zgetrs_` from `libblis`
# (ComplexF64 method). `B` is overwritten by the native call and returned.
#
# Arguments:
#   trans - transpose selector; validated with `chktrans` and forwarded to the routine.
#   A     - square matrix holding the LU factors.
#   ipiv  - pivot vector; its length must equal the order of `A`.
#   B     - right-hand side(s); its first dimension must equal the order of `A`.
# Keyword:
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#
# Throws `DimensionMismatch` when `B` or `ipiv` disagree with the order of `A`; a
# non-zero status is handed to `chklapackerror`.
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The trailing `1` (Clong) is presumably the hidden Fortran length of the `trans`
    # character argument — TODO confirm against the library's calling convention.
    ccall(("zgetrs_", libblis), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
119
+
120
# Solve a linear system in place from LU factors by calling `cgetrs_` from `libblis`
# (ComplexF32 method). `B` is overwritten by the native call and returned.
#
# Arguments:
#   trans - transpose selector; validated with `chktrans` and forwarded to the routine.
#   A     - square matrix holding the LU factors.
#   ipiv  - pivot vector; its length must equal the order of `A`.
#   B     - right-hand side(s); its first dimension must equal the order of `A`.
# Keyword:
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#
# Throws `DimensionMismatch` when `B` or `ipiv` disagree with the order of `A`; a
# non-zero status is handed to `chklapackerror`.
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The trailing `1` (Clong) is presumably the hidden Fortran length of the `trans`
    # character argument — TODO confirm against the library's calling convention.
    ccall(("cgetrs_", libblis), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
144
+
145
# Solve a linear system in place from LU factors by calling `dgetrs_` from `libblis`
# (Float64 method). `B` is overwritten by the native call and returned.
#
# Arguments:
#   trans - transpose selector; validated with `chktrans` and forwarded to the routine.
#   A     - square matrix holding the LU factors.
#   ipiv  - pivot vector; its length must equal the order of `A`.
#   B     - right-hand side(s); its first dimension must equal the order of `A`.
# Keyword:
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#
# Throws `DimensionMismatch` when `B` or `ipiv` disagree with the order of `A`; a
# non-zero status is handed to `chklapackerror`.
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The trailing `1` (Clong) is presumably the hidden Fortran length of the `trans`
    # character argument — TODO confirm against the library's calling convention.
    ccall(("dgetrs_", libblis), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
169
+
170
# Solve a linear system in place from LU factors by calling `sgetrs_` from `libblis`
# (Float32 method). `B` is overwritten by the native call and returned.
#
# Arguments:
#   trans - transpose selector; validated with `chktrans` and forwarded to the routine.
#   A     - square matrix holding the LU factors.
#   ipiv  - pivot vector; its length must equal the order of `A`.
#   B     - right-hand side(s); its first dimension must equal the order of `A`.
# Keyword:
#   info  - `Ref{BlasInt}` that receives the routine's status code.
#
# Throws `DimensionMismatch` when `B` or `ipiv` disagree with the order of `A`; a
# non-zero status is handed to `chklapackerror`.
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The trailing `1` (Clong) is presumably the hidden Fortran length of the `trans`
    # character argument — TODO confirm against the library's calling convention.
    ccall(("sgetrs_", libblis), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
194
+
195
# Return `false` for `BLISLUFactorization`, regardless of the other two arguments.
# NOTE(review): written without a module qualifier, this defines `default_alias_A` in
# this extension module. If the intent is to add a method to a LinearSolve function of
# the same name, it likely needs the `LinearSolve.` prefix — confirm.
function default_alias_A(::BLISLUFactorization, ::Any, ::Any)
    return false
end
196
# Return `false` for `BLISLUFactorization`, regardless of the other two arguments.
# NOTE(review): written without a module qualifier, this defines `default_alias_b` in
# this extension module. If the intent is to add a method to a LinearSolve function of
# the same name, it likely needs the `LinearSolve.` prefix — confirm.
function default_alias_b(::BLISLUFactorization, ::Any, ::Any)
    return false
end
197
+
198
# Placeholder cache value: the result of `ArrayInterface.lu_instance` on an empty 0×0
# Float64 matrix, paired with a `Ref{BlasInt}` status slot.
# A `let` block keeps the scratch matrix local; a top-level `begin` block would leave
# its temporaries behind as untyped, non-const module globals.
const PREALLOCATED_BLIS_LU = let empty_mat = rand(0, 0)
    (ArrayInterface.lu_instance(empty_mat), Ref{BlasInt}())
end
202
+
203
# Fallback cache initializer for `BLISLUFactorization`: none of the arguments are
# inspected and the module-level `PREALLOCATED_BLIS_LU` placeholder is handed back.
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_BLIS_LU
end
208
+
209
# Cache initializer for Float32 / ComplexF32 / ComplexF64 matrices: builds a placeholder
# from an empty 0×0 matrix with the element type of `A`, paired with a fresh
# `Ref{BlasInt}` status slot. Only `eltype(A)` is read from the arguments.
function LinearSolve.init_cacheval(alg::BLISLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    placeholder = rand(eltype(A), 0, 0)
    return (ArrayInterface.lu_instance(placeholder), Ref{BlasInt}())
end
215
+
216
# Solve the linear system held in `cache` using the BLIS-backed LU path.
#
# When `cache.isfresh` is set, `cache.A` is factorized with `getrf!`, reusing the pivot
# buffer and status `Ref` taken from the current cache value; the resulting
# `(LU, status_ref)` pair replaces `cache.cacheval` and `cache.isfresh` is cleared.
# The solution is then computed with `ldiv!` into `cache.u` and wrapped by
# `SciMLBase.build_linear_solution`. `kwargs` are accepted but not used here.
function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
        kwargs...)
    mat = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        stored = @get_cacheval(cache, :BLISLUFactorization)
        factors, pivots, code, status_ref = getrf!(mat; ipiv = stored[1].ipiv,
            info = stored[2])
        cache.cacheval = (LU(factors, pivots, code), status_ref)
        cache.isfresh = false
    end

    sol = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b)
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)

    #=
    Disabled alternative that solves through `getrs!` directly:

    A, info = @get_cacheval(cache, :BLISLUFactorization)
    LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
    m, n = size(A, 1), size(A, 2)
    if m > n
        Bc = copy(cache.b)
        getrs!('N', A.factors, A.ipiv, Bc; info)
        return copyto!(cache.u, 1, Bc, 1, n)
    else
        copyto!(cache.u, cache.b)
        getrs!('N', A.factors, A.ipiv, cache.u; info)
    end

    SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
    =#
end
247
+
248
+ end
0 commit comments