Skip to content

Commit 2fea1c2

Browse files
WIP: Wrap BLIS
Test case: ```julia using LinearSolve, blis_jll A = rand(4, 4) b = rand(4) prob = LinearProblem(A, b) sol = solve(prob,LinearSolve.BLISLUFactorization()) sol.u ``` throws: ```julia julia> sol = solve(prob,LinearSolve.BLISLUFactorization()) ERROR: TypeError: in ccall: first argument not a pointer or valid constant expression, expected Ptr, got a value of type Tuple{Symbol, Ptr{Nothing}} Stacktrace: [1] getrf!(A::Matrix{Float64}; ipiv::Vector{Int64}, info::Base.RefValue{Int64}, check::Bool) @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:67 [2] getrf! @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:55 [inlined] [3] #solve!#9 @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:222 [inlined] [4] solve! @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:216 [inlined] [5] #solve!#6 @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:209 [inlined] [6] solve! @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:208 [inlined] [7] #solve#5 @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:205 [inlined] [8] solve(::LinearProblem{…}, ::LinearSolve.BLISLUFactorization) @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:202 [9] top-level scope @ REPL[8]:1 Some type information was truncated. Use `show(err)` to see complete types. ```
1 parent a455e27 commit 2fea1c2

File tree

3 files changed

+253
-0
lines changed

3 files changed

+253
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3232
[weakdeps]
3333
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3434
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
35+
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
3536
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3637
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3738
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
@@ -44,6 +45,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4445

4546
[extensions]
4647
LinearSolveBandedMatricesExt = "BandedMatrices"
48+
LinearSolveBLISExt = "blis_jll"
4749
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4850
LinearSolveCUDAExt = "CUDA"
4951
LinearSolveEnzymeExt = "Enzyme"
@@ -58,6 +60,7 @@ LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
5860
[compat]
5961
ArrayInterface = "7.4.11"
6062
BandedMatrices = "1"
63+
blis_jll = "0.9.0"
6164
BlockDiagonals = "0.1"
6265
ConcreteStructs = "0.2"
6366
DocStringExtensions = "0.9"

ext/LinearSolveBLISExt.jl

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
module LinearSolveBLISExt
2+
3+
using Libdl
4+
using blis_jll
5+
using LinearAlgebra
6+
using LinearSolve
7+
8+
using LinearAlgebra: BlasInt, LU
9+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
10+
@blasfunc, chkargsok
11+
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase
12+
13+
const global libblis = dlopen(blis_jll.blis_path)
14+
15+
function getrf!(A::AbstractMatrix{<:ComplexF64};
16+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
17+
info = Ref{BlasInt}(),
18+
check = false)
19+
require_one_based_indexing(A)
20+
check && chkfinite(A)
21+
chkstride1(A)
22+
m, n = size(A)
23+
lda = max(1, stride(A, 2))
24+
if isempty(ipiv)
25+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
26+
end
27+
ccall((@blasfunc(zgetrf_), libblis), Cvoid,
28+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
29+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
30+
m, n, A, lda, ipiv, info)
31+
chkargsok(info[])
32+
A, ipiv, info[], info #Error code is stored in LU factorization type
33+
end
34+
35+
function getrf!(A::AbstractMatrix{<:ComplexF32};
36+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
37+
info = Ref{BlasInt}(),
38+
check = false)
39+
require_one_based_indexing(A)
40+
check && chkfinite(A)
41+
chkstride1(A)
42+
m, n = size(A)
43+
lda = max(1, stride(A, 2))
44+
if isempty(ipiv)
45+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
46+
end
47+
ccall((@blasfunc(cgetrf_), libblis), Cvoid,
48+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
49+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
50+
m, n, A, lda, ipiv, info)
51+
chkargsok(info[])
52+
A, ipiv, info[], info #Error code is stored in LU factorization type
53+
end
54+
55+
function getrf!(A::AbstractMatrix{<:Float64};
56+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
57+
info = Ref{BlasInt}(),
58+
check = false)
59+
require_one_based_indexing(A)
60+
check && chkfinite(A)
61+
chkstride1(A)
62+
m, n = size(A)
63+
lda = max(1, stride(A, 2))
64+
if isempty(ipiv)
65+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
66+
end
67+
ccall((@blasfunc(dgetrf_), libblis), Cvoid,
68+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
69+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
70+
m, n, A, lda, ipiv, info)
71+
chkargsok(info[])
72+
A, ipiv, info[], info #Error code is stored in LU factorization type
73+
end
74+
75+
function getrf!(A::AbstractMatrix{<:Float32};
76+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
77+
info = Ref{BlasInt}(),
78+
check = false)
79+
require_one_based_indexing(A)
80+
check && chkfinite(A)
81+
chkstride1(A)
82+
m, n = size(A)
83+
lda = max(1, stride(A, 2))
84+
if isempty(ipiv)
85+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
86+
end
87+
ccall((@blasfunc(sgetrf_), libblis), Cvoid,
88+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
89+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
90+
m, n, A, lda, ipiv, info)
91+
chkargsok(info[])
92+
A, ipiv, info[], info #Error code is stored in LU factorization type
93+
end
94+
95+
function getrs!(trans::AbstractChar,
96+
A::AbstractMatrix{<:ComplexF64},
97+
ipiv::AbstractVector{BlasInt},
98+
B::AbstractVecOrMat{<:ComplexF64};
99+
info = Ref{BlasInt}())
100+
require_one_based_indexing(A, ipiv, B)
101+
LinearAlgebra.LAPACK.chktrans(trans)
102+
chkstride1(A, B, ipiv)
103+
n = LinearAlgebra.checksquare(A)
104+
if n != size(B, 1)
105+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
106+
end
107+
if n != length(ipiv)
108+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
109+
end
110+
nrhs = size(B, 2)
111+
ccall(("zgetrs_", libblis), Cvoid,
112+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
113+
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
114+
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
115+
1)
116+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
117+
B
118+
end
119+
120+
function getrs!(trans::AbstractChar,
121+
A::AbstractMatrix{<:ComplexF32},
122+
ipiv::AbstractVector{BlasInt},
123+
B::AbstractVecOrMat{<:ComplexF32};
124+
info = Ref{BlasInt}())
125+
require_one_based_indexing(A, ipiv, B)
126+
LinearAlgebra.LAPACK.chktrans(trans)
127+
chkstride1(A, B, ipiv)
128+
n = LinearAlgebra.checksquare(A)
129+
if n != size(B, 1)
130+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
131+
end
132+
if n != length(ipiv)
133+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
134+
end
135+
nrhs = size(B, 2)
136+
ccall(("cgetrs_", libblis), Cvoid,
137+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
138+
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
139+
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
140+
1)
141+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
142+
B
143+
end
144+
145+
function getrs!(trans::AbstractChar,
146+
A::AbstractMatrix{<:Float64},
147+
ipiv::AbstractVector{BlasInt},
148+
B::AbstractVecOrMat{<:Float64};
149+
info = Ref{BlasInt}())
150+
require_one_based_indexing(A, ipiv, B)
151+
LinearAlgebra.LAPACK.chktrans(trans)
152+
chkstride1(A, B, ipiv)
153+
n = LinearAlgebra.checksquare(A)
154+
if n != size(B, 1)
155+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
156+
end
157+
if n != length(ipiv)
158+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
159+
end
160+
nrhs = size(B, 2)
161+
ccall(("dgetrs_", libblis), Cvoid,
162+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
163+
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
164+
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
165+
1)
166+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
167+
B
168+
end
169+
170+
function getrs!(trans::AbstractChar,
171+
A::AbstractMatrix{<:Float32},
172+
ipiv::AbstractVector{BlasInt},
173+
B::AbstractVecOrMat{<:Float32};
174+
info = Ref{BlasInt}())
175+
require_one_based_indexing(A, ipiv, B)
176+
LinearAlgebra.LAPACK.chktrans(trans)
177+
chkstride1(A, B, ipiv)
178+
n = LinearAlgebra.checksquare(A)
179+
if n != size(B, 1)
180+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
181+
end
182+
if n != length(ipiv)
183+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
184+
end
185+
nrhs = size(B, 2)
186+
ccall(("sgetrs_", libblis), Cvoid,
187+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
188+
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
189+
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
190+
1)
191+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
192+
B
193+
end
194+
195+
default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false
196+
default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false
197+
198+
const PREALLOCATED_BLIS_LU = begin
199+
A = rand(0, 0)
200+
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
201+
end
202+
203+
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,
204+
maxiters::Int, abstol, reltol, verbose::Bool,
205+
assumptions::OperatorAssumptions)
206+
PREALLOCATED_BLIS_LU
207+
end
208+
209+
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,
210+
maxiters::Int, abstol, reltol, verbose::Bool,
211+
assumptions::OperatorAssumptions)
212+
A = rand(eltype(A), 0, 0)
213+
ArrayInterface.lu_instance(A), Ref{BlasInt}()
214+
end
215+
216+
function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
217+
kwargs...)
218+
A = cache.A
219+
A = convert(AbstractMatrix, A)
220+
if cache.isfresh
221+
cacheval = @get_cacheval(cache, :BLISLUFactorization)
222+
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
223+
fact = LU(res[1:3]...), res[4]
224+
cache.cacheval = fact
225+
cache.isfresh = false
226+
end
227+
228+
y = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b)
229+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
230+
231+
#=
232+
A, info = @get_cacheval(cache, :BLISLUFactorization)
233+
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
234+
m, n = size(A, 1), size(A, 2)
235+
if m > n
236+
Bc = copy(cache.b)
237+
getrs!('N', A.factors, A.ipiv, Bc; info)
238+
return copyto!(cache.u, 1, Bc, 1, n)
239+
else
240+
copyto!(cache.u, cache.b)
241+
getrs!('N', A.factors, A.ipiv, cache.u; info)
242+
end
243+
244+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
245+
=#
246+
end
247+
248+
end

src/extension_algs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,5 @@ A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pr
326326
to avoid allocations and automatically offloads to the GPU.
327327
"""
328328
struct MetalLUFactorization <: AbstractFactorization end
329+
330+
struct BLISLUFactorization <: AbstractFactorization end

0 commit comments

Comments
 (0)