
propagate() is 20x slower than built-in sparse matmul #106

@learning-chip

Description


With the well-known graph-matrix duality (see the GraphBLAS intro, Fig. 1), simple graph message-passing kernels are equivalent to sparse matrix-vector multiplication (SpMV) or sparse matrix-matrix multiplication (SpMM). However, I noticed that propagate() is more than 20x slower than the built-in A * B for an equivalent operation. I ran the same test with DGL and did not observe such a drastic slowdown.
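To spell the equivalence out (my notation, not from the GraphBLAS paper or the package docs): the graph has an edge j → i for every nonzero A[j, i], so summing the messages e ⋅ xj = A[j, i] ⋅ b[j] at each target node i is exactly

$$(bA)_i = \sum_j b_j \, A_{ji},$$

i.e. the SpMV b * A; stacking feature rows does the same thing column by column and gives the SpMM B * A.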

To reproduce

using SparseArrays
using GraphNeuralNetworks
using BenchmarkTools
import Random: seed!

n = 1024
seed!(0)
A = sprand(n, n, 0.01)
b = rand(1, n)
B = rand(100, n)

g = GNNGraph(
    A,
    ndata=(; b=b, B=B),
    edata=(; A=reshape(A.nzval, 1, :)),
    graph_type=:coo  # changing to :sparse has little effect on performance
)

function spmv(g)
    propagate(
        (xi, xj, e) -> e .* xj,  # same as e_mul_xj
        g, +; xj=g.ndata.b, e=g.edata.A
    )
end

function spmm(g)
    propagate(
        (xi, xj, e) -> e .* xj,  # same as e_mul_xj
        g, +; xj=g.ndata.B, e=g.edata.A
    )
end

isequal(spmv(g), b * A)  # true
@btime spmv(g)  # ~5 ms
@btime b * A  # ~32 us

isequal(spmm(g), B * A)  # true
@btime spmm(g)  # ~9 ms
@btime B * A  # ~400 us

Such a performance gap can't be explained by storing the sparse matrix in COO (the GNN libraries' default) vs CSR (SciPy's default) vs CSC (Julia's default). In the code below, changing the SciPy matrix format has only a minor effect on speed. Also, DGL and SciPy run at similar speeds.
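As a rough sanity check on the Julia side (my own sketch, reusing the A and b defined in the reproduction script above): even a naive hand-written SpMV loop over the COO triplets performs only nnz multiply-adds, so the storage format by itself should not account for a ~100x slowdown.

using SparseArrays
using BenchmarkTools

# Naive COO SpMV over the nonzero triplets of A:
# y[i] = sum_j b[j] * A[j, i]  (A is square n x n, so y has the same length as b)
function coo_spmv(Is, Js, Vs, b)
    y = zeros(promote_type(eltype(Vs), eltype(b)), length(b))
    @inbounds for k in eachindex(Vs)
        y[Js[k]] += Vs[k] * b[Is[k]]
    end
    return y
end

Is, Js, Vs = findnz(A)                      # COO triplets of the A defined above
coo_spmv(Is, Js, Vs, vec(b)) ≈ vec(b * A)   # expected: true
@btime coo_spmv($Is, $Js, $Vs, $(vec(b)))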

Compare with DGL and SciPy

import numpy as np
import scipy.sparse as sp
import torch

import dgl
import dgl.function as fn

n = 1024

np.random.seed(0)
A = sp.random(n, n, density=0.01, format='csc')  # changing format to `coo` or `csr` affects performance, but not much
b = np.random.rand(1, n)
B = np.random.rand(100, n)

g = dgl.from_scipy(A)
g.edata['A'] = torch.tensor(A.data[:, np.newaxis])
g.ndata['b'] = torch.tensor(b.T)
g.ndata['B'] = torch.tensor(B.T)

def spmv(g):
    with g.local_scope():
        g.update_all(fn.e_mul_u('A', 'b', 'm'), fn.sum('m', 'bA'))
        return g.ndata['bA']
    
def spmm(g):
    with g.local_scope():
        g.update_all(fn.e_mul_u('A', 'B', 'M'), fn.sum('M', 'BA'))
        return g.ndata['BA']

np.array_equal(spmv(g).numpy().T, b @ A)  # True
%timeit spmv(g)  # ~200 us
%timeit b @ A  # ~70 us

np.array_equal(spmm(g).numpy().T, B @ A)  # True
%timeit spmm(g)  # ~900 us
%timeit B @ A  # ~900 us

Effect of fusion

DGL's update_all fuses the message and reduction kernels. To mimic the two-stage propagate and check whether this fusion accounts for the performance difference:

def spmv_twostage(g):
    with g.local_scope():
        g.apply_edges(fn.e_mul_u('A', 'b', 'm'))
        g.update_all(
            fn.copy_e('m', 'm'),
            fn.sum('m', 'bA')
        )
        return g.ndata['bA']

%timeit spmv_twostage(g)  # ~240 us; just 20% slower

The unfused version is only slightly slower, so the lack of fusion cannot explain the 5 ms vs 200 µs gap.
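For reference, my understanding is that the unfused path on the Julia side amounts to gathering per-edge messages into a dense (features × num_edges) array and then scatter-adding them onto the target nodes, roughly like the sketch below. This is my own reconstruction, not the actual implementation: edge_index(g) as the accessor for the COO indices, and NNlib.scatter with its dstsize keyword, are assumptions on my part.

using NNlib  # for scatter

# Hypothetical two-stage version of spmm(g) over the COO edge list (s[k] -> t[k]):
#   1) message: M[:, k] = w[k] .* X[:, s[k]]  -- materializes a dense (features x num_edges) array
#   2) reduce:  out[:, i] = sum of M[:, k] over edges with t[k] == i
function twostage_spmm(s, t, w, X, n)
    M = w .* X[:, s]                                         # gather + elementwise message
    return NNlib.scatter(+, M, t; dstsize=(size(X, 1), n))   # scatter-add onto target nodes
end

s, t = edge_index(g)                       # assumed accessor for the graph's COO indices
out = twostage_spmm(s, t, g.edata.A, B, n)
out ≈ B * A                                # expected: true

If most of the 5 ms goes into the gather X[:, s], the temporary M, and a generic scatter, then dispatching the built-in e_mul_xj / + combination to a plain sparse matmul should close most of the gap.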

There must be other causes of the inefficiency. I'd like to figure them out and bring the performance at least close to DGL's. (I use DGL a lot, but some projects favor an all-Julia implementation, and your package looks like a good option with clean syntax 🙂)
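Until propagate gets a faster path, a workaround I'd consider (my own sketch; weighted_adjacency is a hypothetical helper, and I'm assuming edge_index(g) returns the edge list in the same order as g.edata.A) is to drop back to a SparseMatrixCSC and use the built-in multiply:

using SparseArrays
using GraphNeuralNetworks

# Rebuild the weighted adjacency matrix from the graph's COO edge list and weights,
# then rely on Julia's built-in SpMV/SpMM instead of propagate.
function weighted_adjacency(g)
    s, t = edge_index(g)      # assumed accessor for source/target indices
    n = g.num_nodes
    sparse(s, t, vec(g.edata.A), n, n)
end

A2 = weighted_adjacency(g)
B * A2 ≈ B * A    # expected: true, and back at built-in SpMM speed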

Package versions

Julia:

  • GraphNeuralNetworks.jl 0.3.8
  • Julia 1.7

Python:

  • DGL 0.7.1
  • PyTorch 1.9.1
