Skip to content

Commit 2b1883f

Browse files
Merge pull request #498 from jd-foster/blis-ext-lapack_jll
Enable "WIP: Wrap BLIS" with reference LAPACK
2 parents afcc28e + 3c2ffe0 commit 2b1883f

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3939
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
4040
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4141
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
42+
LAPACK_jll = "51474c39-65e3-53ba-86ba-03b1b862ec14"
4243
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4344
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4445
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4546

4647
[extensions]
4748
LinearSolveBandedMatricesExt = "BandedMatrices"
48-
LinearSolveBLISExt = "blis_jll"
49+
LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
4950
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
5051
LinearSolveCUDAExt = "CUDA"
5152
LinearSolveEnzymeExt = "Enzyme"

ext/LinearSolveBLISExt.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@ module LinearSolveBLISExt
22

33
using Libdl
44
using blis_jll
5+
using LAPACK_jll
56
using LinearAlgebra
67
using LinearSolve
78

8-
using LinearAlgebra: BlasInt, LU
9+
using LinearAlgebra: libblastrampoline, BlasInt, LU
910
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
1011
@blasfunc, chkargsok
1112
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase
1213

1314
const global libblis = blis_jll.blis
15+
const global liblapack = libblastrampoline
16+
17+
BLAS.lbt_forward(libblis; clear=true, verbose=true, suffix_hint="64_")
18+
BLAS.lbt_forward(LAPACK_jll.liblapack_path; suffix_hint="64_", verbose=true)
1419

1520
function getrf!(A::AbstractMatrix{<:ComplexF64};
1621
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
@@ -24,7 +29,7 @@ function getrf!(A::AbstractMatrix{<:ComplexF64};
2429
if isempty(ipiv)
2530
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
2631
end
27-
ccall((@blasfunc(zgetrf_), libblis), Cvoid,
32+
ccall((@blasfunc(zgetrf_), liblapack), Cvoid,
2833
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
2934
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
3035
m, n, A, lda, ipiv, info)
@@ -44,7 +49,7 @@ function getrf!(A::AbstractMatrix{<:ComplexF32};
4449
if isempty(ipiv)
4550
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
4651
end
47-
ccall((@blasfunc(cgetrf_), libblis), Cvoid,
52+
ccall((@blasfunc(cgetrf_), liblapack), Cvoid,
4853
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
4954
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
5055
m, n, A, lda, ipiv, info)
@@ -64,7 +69,7 @@ function getrf!(A::AbstractMatrix{<:Float64};
6469
if isempty(ipiv)
6570
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
6671
end
67-
ccall((@blasfunc(dgetrf_), libblis), Cvoid,
72+
ccall((@blasfunc(dgetrf_), liblapack), Cvoid,
6873
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
6974
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
7075
m, n, A, lda, ipiv, info)
@@ -84,7 +89,7 @@ function getrf!(A::AbstractMatrix{<:Float32};
8489
if isempty(ipiv)
8590
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
8691
end
87-
ccall((@blasfunc(sgetrf_), libblis), Cvoid,
92+
ccall((@blasfunc(sgetrf_), liblapack), Cvoid,
8893
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
8994
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
9095
m, n, A, lda, ipiv, info)
@@ -108,7 +113,7 @@ function getrs!(trans::AbstractChar,
108113
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
109114
end
110115
nrhs = size(B, 2)
111-
ccall(("zgetrs_", libblis), Cvoid,
116+
ccall(("zgetrs_", liblapack), Cvoid,
112117
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
113118
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
114119
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
@@ -133,7 +138,7 @@ function getrs!(trans::AbstractChar,
133138
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
134139
end
135140
nrhs = size(B, 2)
136-
ccall(("cgetrs_", libblis), Cvoid,
141+
ccall(("cgetrs_", liblapack), Cvoid,
137142
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
138143
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
139144
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
@@ -158,7 +163,7 @@ function getrs!(trans::AbstractChar,
158163
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
159164
end
160165
nrhs = size(B, 2)
161-
ccall(("dgetrs_", libblis), Cvoid,
166+
ccall(("dgetrs_", liblapack), Cvoid,
162167
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
163168
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
164169
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
@@ -183,7 +188,7 @@ function getrs!(trans::AbstractChar,
183188
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
184189
end
185190
nrhs = size(B, 2)
186-
ccall(("sgetrs_", libblis), Cvoid,
191+
ccall(("sgetrs_", liblapack), Cvoid,
187192
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
188193
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
189194
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,

0 commit comments

Comments
 (0)