Skip to content

Commit 8d2363a

Browse files
test with sparse cuda arrays
1 parent 195bb6c commit 8d2363a

File tree

5 files changed

+41
-12
lines changed

5 files changed

+41
-12
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module GNNGraphs
33
using SparseArrays
44
using Functors: @functor
55
using CUDA
6+
using CUDA.CUSPARSE
67
import Graphs
78
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
89
import Flux

src/GNNGraphs/convert.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,39 @@ function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
137137
s, t, eweight = coo
138138
eweight = isnothing(eweight) ? fill!(similar(s, T), 1) : eweight
139139
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
140-
A = sparse(s, t, eweight, num_nodes, num_nodes)
140+
A = _sparse(s, t, eweight, num_nodes, num_nodes)
141141
num_edges = length(s)
142142
return A, num_nodes, num_edges
143143
end
144144

145+
_sparse(s, t, eweight, n, m) = sparse(s, t, eweight, n, m)
146+
147+
function _sparse(I::CuVector, J::CuVector, V::CuVector, m, n)
148+
spcoo = CuSparseMatrixCOO{Float32, Int32}(Int32.(I), Int32.(J), Float32.(V), (m, n))
149+
return CuSparseMatrixCSR(spcoo)
150+
end
151+
152+
# function _sparse(I::CuVector, J::CuVector, V::CuVector, m, n; fmt=:csr)
153+
# # Tv = Int32
154+
# spcoo = CuSparseMatrixCOO{Float32, Int32}(Int32.(I), Int32.(J), Float32.(V), (m, n))
155+
# if fmt == :csc
156+
# return CuSparseMatrixCSC(spcoo)
157+
# elseif fmt == :csr
158+
# return CuSparseMatrixCSR(spcoo)
159+
# elseif fmt == :coo
160+
# return spcoo
161+
# else
162+
# error("Format :$fmt not available, use :csc, :csr, or :coo.")
163+
# end
164+
# end
165+
166+
167+
# Workaround for https://github.com/JuliaGPU/CUDA.jl/issues/1113#issuecomment-955759875
168+
function Base.:*(A::CuMatrix, B::CuSparseMatrixCSR)
169+
@assert size(A, 2) == size(B, 1)
170+
return CuMatrix((B' * A')')
171+
end
172+
145173

146174
@non_differentiable to_coo(x...)
147175
@non_differentiable to_dense(x...)

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, grap
192192
ndata, edata, gdata)
193193
end
194194

195-
function Base.show(io::IO, g::GNNGraph)
196-
println(io, "GNNGraph:
195+
function Base.show(io::IO, g::GNNGraph{T}) where T
196+
println(io, "GNNGraph{$T}:
197197
num_nodes = $(g.num_nodes)
198198
num_edges = $(g.num_edges)
199199
num_graphs = $(g.num_graphs)")

src/GNNGraphs/query.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ function adjacency_list(g::GNNGraph; dir=:out)
7474
end
7575

7676
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
77-
if g.graph[1] isa CuVector
78-
# TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152
79-
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
80-
else
77+
# if g.graph[1] isa CuVector
78+
# # TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152
79+
# A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
80+
# else
8181
A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)
82-
end
82+
# end
8383
@assert size(A) == (n, n)
8484
return dir == :out ? A : A'
8585
end

src/msgpass.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,10 @@ function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractM
152152
return xj * A
153153
end
154154

155-
## avoid the fast path on gpu until we have better cuda support
156-
function propagate(::typeof(copyxj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
157-
propagate((xi,xj,e)->copyxj(xi,xj,e), g, +, xi, xj, e)
158-
end
155+
# ## avoid the fast path on gpu until we have better cuda support
156+
# function propagate(::typeof(copyxj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
157+
# propagate((xi,xj,e) -> copyxj(xi,xj,e), g, +, xi, xj, e)
158+
# end
159159

160160
# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
161161
# A = adjacency_matrix(g)

0 commit comments

Comments
 (0)