Skip to content

Commit bf0364b

Browse files
authored
Faster dot product for sparse matrices and dense vectors (#39889)
1 parent 6cea0d2 commit bf0364b

File tree

4 files changed

+84
-1
lines changed

4 files changed

+84
-1
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

+8
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,14 @@ end
697697
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
698698
dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)
699699

700+
dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
701+
function dot(D::Diagonal, B::AbstractMatrix)
702+
size(D) == size(B) || throw(DimensionMismatch("Matrix sizes $(size(D)) and $(size(B)) differ"))
703+
return dot(D.diag, view(B, diagind(B)))
704+
end
705+
706+
dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A))
707+
700708
function _mapreduce_prod(f, x, D::Diagonal, y)
701709
if isempty(x) && isempty(D) && isempty(y)
702710
return zero(Base.promote_op(f, eltype(x), eltype(D), eltype(y)))

stdlib/LinearAlgebra/test/diagonal.jl

+9
Original file line numberDiff line numberDiff line change
@@ -735,4 +735,13 @@ end
735735
@test dot(zeros(Int32, 0), Diagonal(zeros(Int, 0)), zeros(Int16, 0)) === 0
736736
end
737737

738+
@testset "Inner product" begin
739+
A = Diagonal(rand(10) .+ im)
740+
B = Diagonal(rand(10) .+ im)
741+
@test dot(A, B) dot(Matrix(A), B)
742+
@test dot(A, B) dot(A, Matrix(B))
743+
@test dot(A, B) dot(Matrix(A), Matrix(B))
744+
@test dot(A, B) conj(dot(B, A))
745+
end
746+
738747
end # module TestDiagonal

stdlib/SparseArrays/src/linalg.jl

+36
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,42 @@ function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector)
376376
r
377377
end
378378

379+
const WrapperMatrixTypes{T,MT} = Union{
380+
SubArray{T,2,MT},
381+
Adjoint{T,MT},
382+
Transpose{T,MT},
383+
AbstractTriangular{T,MT},
384+
UpperHessenberg{T,MT},
385+
Symmetric{T,MT},
386+
Hermitian{T,MT},
387+
}
388+
389+
function dot(A::MA, B::AbstractSparseMatrixCSC{TB}) where {MA<:Union{DenseMatrixUnion,WrapperMatrixTypes{<:Any,Union{DenseMatrixUnion,AbstractSparseMatrix}}},TB}
390+
T = promote_type(eltype(A), TB)
391+
(m, n) = size(A)
392+
if (m, n) != size(B)
393+
throw(DimensionMismatch())
394+
end
395+
s = zero(T)
396+
if m * n == 0
397+
return s
398+
end
399+
rows = rowvals(B)
400+
vals = nonzeros(B)
401+
@inbounds for j in 1:n
402+
for ridx in nzrange(B, j)
403+
i = rows[ridx]
404+
v = vals[ridx]
405+
s += dot(A[i,j], v)
406+
end
407+
end
408+
return s
409+
end
410+
411+
function dot(A::AbstractSparseMatrixCSC{TA}, B::MB) where {TA,MB<:Union{DenseMatrixUnion,WrapperMatrixTypes{<:Any,Union{DenseMatrixUnion,AbstractSparseMatrix}}}}
412+
return conj(dot(B, A))
413+
end
414+
379415
## triangular sparse handling
380416

381417
possible_adjoint(adj::Bool, a::Real) = a

stdlib/SparseArrays/test/sparse.jl

+31-1
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,42 @@ end
472472
end
473473

474474
@testset "sparse Frobenius dot/inner product" begin
475+
full_view = M -> view(M, :, :)
475476
for i = 1:5
476477
A = sprand(ComplexF64,10,15,0.4)
477478
B = sprand(ComplexF64,10,15,0.5)
478-
@test dot(A,B) dot(Matrix(A),Matrix(B))
479+
C = rand(10,15) .> 0.3
480+
@test dot(A,B) dot(Matrix(A), Matrix(B))
481+
@test dot(A,B) dot(A, Matrix(B))
482+
@test dot(A,B) dot(Matrix(A), B)
483+
@test dot(A,C) dot(Matrix(A), C)
484+
@test dot(C,A) dot(C, Matrix(A))
485+
# square matrices required by most linear algebra wrappers
486+
SA = A * A'
487+
SB = B * B'
488+
SC = C * C'
489+
for W in (full_view, LowerTriangular, UpperTriangular, UpperHessenberg, Symmetric, Hermitian)
490+
WA = W(Matrix(SA))
491+
WB = W(Matrix(SB))
492+
WC = W(Matrix(SC))
493+
@test dot(WA,SB) dot(WA, Matrix(SB))
494+
@test dot(SA,WB) dot(Matrix(SA), WB)
495+
@test dot(SA,WC) dot(Matrix(SA), WC)
496+
end
497+
for W in (transpose, adjoint)
498+
WA = W(Matrix(A))
499+
WB = W(Matrix(B))
500+
WC = W(Matrix(C))
501+
TA = copy(W(A))
502+
TB = copy(W(B))
503+
@test dot(WA,TB) dot(WA, Matrix(TB))
504+
@test dot(TA,WB) dot(Matrix(TA), WB)
505+
@test dot(TA,WC) dot(Matrix(TA), WC)
506+
end
479507
end
480508
@test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2))
509+
@test_throws DimensionMismatch dot(rand(5,5),sprand(5,6,0.2))
510+
@test_throws DimensionMismatch dot(sprand(5,5,0.2),rand(5,6))
481511
end
482512

483513
@testset "generalized dot product" begin

0 commit comments

Comments
 (0)