Skip to content

Commit 1f94933

Browse files
Merge pull request #44 from Wimmerer/KLU
Add KLU factorization
2 parents b7f0a96 + b460aa5 commit 1f94933

File tree

4 files changed

+59
-2
lines changed

4 files changed

+59
-2
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.1.4"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
9+
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
910
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1011
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -21,6 +22,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2122
[compat]
2223
ArrayInterface = "3"
2324
IterativeSolvers = "0.9.2"
25+
KLU = "0.2.1"
2426
Krylov = "0.7.9"
2527
KrylovKit = "0.5"
2628
RecursiveFactorization = "0.2"

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using Setfield
1212
using UnPack
1313
using Requires
1414
using SuiteSparse
15+
using KLU
1516
# wrap
1617
import Krylov
1718
import KrylovKit # TODO
@@ -44,7 +45,7 @@ function __init__()
4445
end
4546

4647
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
47-
RFLUFactorizaation, UMFPACKFactorization
48+
RFLUFactorizaation, UMFPACKFactorization, KLUFactorization
4849
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
4950
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5051
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES

src/factorization.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,42 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization)
7171
SciMLBase.build_linear_solution(alg,y,nothing)
7272
end
7373

74+
Base.@kwdef struct KLUFactorization <: AbstractFactorization
75+
reuse_symbolic::Bool = true
76+
end
77+
78+
function init_cacheval(::KLUFactorization, A, b, u)
79+
if A isa AbstractDiffEqOperator
80+
A = A.A
81+
end
82+
if A isa SparseMatrixCSC
83+
return klu(A)
84+
else
85+
error("KLU is not defined for $(typeof(A))")
86+
end
87+
end
88+
89+
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization)
90+
A = cache.A
91+
if A isa AbstractDiffEqOperator
92+
A = A.A
93+
end
94+
if cache.isfresh
95+
if cache.cacheval !== nothing && alg.reuse_symbolic
96+
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
97+
# This won't recompute if it does.
98+
KLU.klu_analyze!(cache.cacheval)
99+
fact = klu!(cache.cacheval, A)
100+
else
101+
fact = init_cacheval(alg, A, cache.b, cache.u)
102+
end
103+
cache = set_cacheval(cache, fact)
104+
end
105+
106+
y = ldiv!(cache.u, cache.cacheval, cache.b)
107+
SciMLBase.build_linear_solution(alg,y,nothing)
108+
end
109+
74110
## QRFactorization
75111

76112
struct QRFactorization{P} <: AbstractFactorization

test/runtests.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,30 @@ end
7474
test_interface(UMFPACKFactorization(), prob1, prob2)
7575

7676
# Test that refactoring wrong throws.
77-
cache = SciMLBase.init(prob1,UMFPACKFactorization(reuse_symbolic=true); cache_kwargs...) # initialize cache
77+
cache = SciMLBase.init(prob1,UMFPACKFactorization(); cache_kwargs...) # initialize cache
7878
y = solve(cache)
7979
cache = LinearSolve.set_A(cache,sprand(n, n, 0.8))
8080
@test_throws ArgumentError solve(cache)
8181
end
8282

83+
@testset "KLU Factorization" begin
84+
A1 = A/1; b1 = rand(n); x1 = zero(b)
85+
A2 = A/2; b2 = rand(n); x2 = zero(b)
86+
87+
prob1 = LinearProblem(sparse(A1), b1; u0=x1)
88+
prob2 = LinearProblem(sparse(A2), b2; u0=x2)
89+
test_interface(KLUFactorization(), prob1, prob2)
90+
91+
# Test that refactoring wrong throws.
92+
cache = SciMLBase.init(prob1,KLUFactorization(); cache_kwargs...) # initialize cache
93+
y = solve(cache)
94+
X = copy(A1)
95+
X[8,8] = 0.0
96+
X[7,8] = 1.0
97+
cache = LinearSolve.set_A(cache,sparse(X))
98+
@test_throws ArgumentError solve(cache)
99+
end
100+
83101
@testset "Concrete Factorizations" begin
84102
for alg in (
85103
LUFactorization(),

0 commit comments

Comments
 (0)