Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 991ecf8

Browse files
Merge #692
692: Fix sparse mul! r=maleadt a=amontoison close #629 #630 , #637 @haampie Co-authored-by: Alexis Montoison <[email protected]>
2 parents c38da71 + 3e39435 commit 991ecf8

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/sparse/interfaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:Blas
1111

1212
LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O')
1313
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
14-
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(transA),B,zero(T),C,'O')
14+
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
1515
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O')
1616
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
1717
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')

test/sparse.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,4 +1970,50 @@ end
19701970
end
19711971
end
19721972

1973+
@testset "mul!" begin
1974+
for elty in [Float32,Float64,ComplexF32,ComplexF64]
1975+
A = sparse(rand(elty,m,m))
1976+
x = rand(elty,m)
1977+
y = rand(elty,m)
1978+
@testset "csr" begin
1979+
d_x = CuArray(x)
1980+
d_y = CuArray(y)
1981+
d_A = CuSparseMatrixCSR(A)
1982+
d_Aᵀ = transpose(d_A)
1983+
d_Aᴴ = adjoint(d_A)
1984+
CUSPARSE.mul!(d_y, d_A, d_x)
1985+
h_y = collect(d_y)
1986+
z = A * x
1987+
@test z h_y
1988+
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
1989+
h_y = collect(d_y)
1990+
z = transpose(A) * x
1991+
@test z h_y
1992+
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
1993+
h_y = collect(d_y)
1994+
z = adjoint(A) * x
1995+
@test z h_y
1996+
end
1997+
@testset "csc" begin
1998+
d_x = CuArray(x)
1999+
d_y = CuArray(y)
2000+
d_A = CuSparseMatrixCSC(A)
2001+
d_Aᵀ = transpose(d_A)
2002+
d_Aᴴ = adjoint(d_A)
2003+
CUSPARSE.mul!(d_y, d_A, d_x)
2004+
h_y = collect(d_y)
2005+
z = A * x
2006+
@test z h_y
2007+
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
2008+
h_y = collect(d_y)
2009+
z = transpose(A) * x
2010+
@test z h_y
2011+
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
2012+
h_y = collect(d_y)
2013+
z = adjoint(A) * x
2014+
@test z h_y
2015+
end
2016+
end
2017+
end
2018+
19732019
end

0 commit comments

Comments
 (0)