Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 16c0080

Browse files
Merge #692
692: Fix sparse mul! r=amontoison a=amontoison close #629 #630 , #637 @haampie Co-authored-by: Alexis Montoison <[email protected]>
2 parents ee16e77 + 2ec414f commit 16c0080

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

src/sparse/interfaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:Blas
1111

1212
LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O')
1313
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
14-
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(transA),B,zero(T),C,'O')
14+
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
1515
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O')
1616
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
1717
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')

src/sparse/wrappers.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,17 +567,21 @@ for (fname,elty) in ((:cusparseScsrmv, :Float32),
567567
ctransa = 'T'
568568
end
569569
cutransa = cusparseop(ctransa)
570-
cuind = cusparseindex(index)
571-
cudesc = getDescr(A,index)
572570
n,m = Mat.dims
573571
if ctransa == 'N'
574-
chkmvdims(X,n,Y,m)
572+
chkmvdims(X, n, Y, m)
575573
end
576574
if ctransa == 'T' || ctransa == 'C'
577-
chkmvdims(X,m,Y,n)
575+
chkmvdims(X, m, Y, n)
576+
end
577+
cudesc = getDescr(A,index)
578+
nzVal = Mat.nzVal
579+
if transa == 'C' && $elty <: Complex
580+
nzVal = conj(Mat.nzVal)
578581
end
579-
$fname(handle(), cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc),
580-
Mat.nzVal, Mat.colPtr, Mat.rowVal, X, [beta], Y)
582+
$fname(handle(),
583+
cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc), nzVal,
584+
Mat.colPtr, Mat.rowVal, X, [beta], Y)
581585
Y
582586
end
583587
end

test/sparse.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,4 +2006,50 @@ end
20062006
end
20072007
end
20082008

2009+
@testset "mul!" begin
2010+
for elty in [Float32,Float64,ComplexF32,ComplexF64]
2011+
A = sparse(rand(elty,m,m))
2012+
x = rand(elty,m)
2013+
y = rand(elty,m)
2014+
@testset "csr -- $elty" begin
2015+
d_x = CuArray(x)
2016+
d_y = CuArray(y)
2017+
d_A = CuSparseMatrixCSR(A)
2018+
d_Aᵀ = transpose(d_A)
2019+
d_Aᴴ = adjoint(d_A)
2020+
CUSPARSE.mul!(d_y, d_A, d_x)
2021+
h_y = collect(d_y)
2022+
z = A * x
2023+
@test z h_y
2024+
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
2025+
h_y = collect(d_y)
2026+
z = transpose(A) * x
2027+
@test z h_y
2028+
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
2029+
h_y = collect(d_y)
2030+
z = adjoint(A) * x
2031+
@test z h_y
2032+
end
2033+
@testset "csc -- $elty" begin
2034+
d_x = CuArray(x)
2035+
d_y = CuArray(y)
2036+
d_A = CuSparseMatrixCSC(A)
2037+
d_Aᵀ = transpose(d_A)
2038+
d_Aᴴ = adjoint(d_A)
2039+
CUSPARSE.mul!(d_y, d_A, d_x)
2040+
h_y = collect(d_y)
2041+
z = A * x
2042+
@test z h_y
2043+
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
2044+
h_y = collect(d_y)
2045+
z = transpose(A) * x
2046+
@test z h_y
2047+
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
2048+
h_y = collect(d_y)
2049+
z = adjoint(A) * x
2050+
@test z h_y
2051+
end
2052+
end
2053+
end
2054+
20092055
end

0 commit comments

Comments
 (0)