Skip to content

mul/ewise rules for basic arithmetic semiring #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -5,7 +5,10 @@ version = "0.4.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -15,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
SSGraphBLAS_jll = "5.1.2"
CEnum = "0.4"
ContextVariablesX = "0.1"
MacroTools = "0.5"
SSGraphBLAS_jll = "5.1"
julia = "1.6"
CEnum = "0.4.1"
ContextVariablesX = "0.1.1"
MacroTools = "0.5.6"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.7"
FiniteDifferences = "0.12"
9 changes: 7 additions & 2 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
@@ -87,9 +87,14 @@ include("operations/kronecker.jl")
include("print.jl")
include("import.jl")
include("export.jl")

#EXPERIMENTAL
include("options.jl")
#EXPERIMENTAL
include("chainrules/chainruleutils.jl")
include("chainrules/mulrules.jl")
include("chainrules/ewiserules.jl")
include("chainrules/maprules.jl")
include("chainrules/reducerules.jl")
include("chainrules/selectrules.jl")
#include("random.jl")
include("misc.jl")
export libgb
46 changes: 46 additions & 0 deletions src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import FiniteDifferences
import LinearAlgebra
import ChainRulesCore: frule, rrule
using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

#Required for ChainRulesTestUtils
function FiniteDifferences.to_vec(M::GBMatrix)
I, J, X = findnz(M)
function backtomat(xvec)
return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2))
end
return X, backtomat
end

function FiniteDifferences.to_vec(v::GBVector)
i, x = findnz(v)
function backtovec(xvec)
return GBVector(i, xvec; nrows=size(v, 1))
end
return x, backtovec
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, J, _ = findnz(x)
return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2))
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, _ = findnz(x)
return GBVector(I, v; nrows = size(x, 1))
end

FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent()
# LinearAlgebra.norm freaks over the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
71 changes: 71 additions & 0 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#emul TIMES
function frule(
(_, ΔA, ΔB, _),
::typeof(emul),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.TIMES)
)
Ω = emul(A, B, BinaryOps.TIMES)
∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES)
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, emulpb
end

############
# eadd rules
############

# PLUS
######

function frule(
(_, ΔA, ΔB, _),
::typeof(eadd),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.PLUS)
)
Ω = eadd(A, B, BinaryOps.PLUS)
∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
function pluspullback(ΔΩ)
return (
NoTangent(),
mask(ΔΩ, A; structural = true),
mask(ΔΩ, B; structural = true),
NoTangent()
)
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end

# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule.
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS)
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, eaddpb
end
17 changes: 17 additions & 0 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
# than AbstractOp.
#function rrule(map, f, xs)
# # Rather than 3 maps really want 1 multimap
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
# ys = map(first, ys_and_pullbacks)
# pullbacks = map(last, ys_and_pullbacks)
# function map_pullback(dys)
# _call(f, x) = f(x)
# dfs_and_dxs = map(_call, pullbacks, dys)
# # but in your case you know it will be NoTangent() so can skip
# df = sum(first, dfs_and_dxs)
# dxs = map(last, dfs_and_dxs)
# return NoTangent(), df, dxs
# end
# return ys, map_pullback
#end
51 changes: 51 additions & 0 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Standard arithmetic mul:
function frule(
(_, ΔA, ΔB),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES)
end
function frule(
(_, ΔA, ΔB, _),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
Ω = mul(A, B)
∂Ω = mul(ΔA, B) + mul(A, ΔB)
return Ω, ∂Ω
end
# Tests will not pass for this. For two reasons.
# First is #25, the output inference is not type stable.
# That's it's own issue.

# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.

function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B'; mask=A)
∂B = mul(A', ΔΩ; mask=B)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return mul(A, B), mulpullback
end


function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES)
pullback(ΔΩ) = mulpullback(ΔΩ)[1:3]
return Ω, pullback
end
Empty file added src/chainrules/reducerules.jl
Empty file.
Empty file added src/chainrules/selectrules.jl
Empty file.
22 changes: 11 additions & 11 deletions src/lib/LibGraphBLAS.jl
Original file line number Diff line number Diff line change
@@ -27,27 +27,27 @@ macro wraperror(code)
elseif info == GrB_NO_VALUE
return nothing
else
if info == GrB_UNINITIALIZED_OBJECT
if info == GrB_UNINITIALIZED_OBJECT
throw(UninitializedObjectError)
elseif info == GrB_INVALID_OBJECT
elseif info == GrB_INVALID_OBJECT
throw(InvalidObjectError)
elseif info == GrB_NULL_POINTER
elseif info == GrB_NULL_POINTER
throw(NullPointerError)
elseif info == GrB_INVALID_VALUE
elseif info == GrB_INVALID_VALUE
throw(InvalidValueError)
elseif info == GrB_INVALID_INDEX
elseif info == GrB_INVALID_INDEX
throw(InvalidIndexError)
elseif info == GrB_DOMAIN_MISMATCH
elseif info == GrB_DOMAIN_MISMATCH
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
elseif info == GrB_DIMENSION_MISMATCH
throw(DimensionMismatch())
elseif info == GrB_OUTPUT_NOT_EMPTY
elseif info == GrB_OUTPUT_NOT_EMPTY
throw(OutputNotEmptyError)
elseif info == GrB_OUT_OF_MEMORY
elseif info == GrB_OUT_OF_MEMORY
throw(OutOfMemoryError())
elseif info == GrB_INSUFFICIENT_SPACE
elseif info == GrB_INSUFFICIENT_SPACE
throw(InsufficientSpaceError)
elseif info == GrB_INDEX_OUT_OF_BOUNDS
elseif info == GrB_INDEX_OUT_OF_BOUNDS
throw(BoundsError())
elseif info == GrB_PANIC
throw(PANIC)
@@ -843,7 +843,7 @@ for T ∈ valid_vec
nvals = GrB_Vector_nvals(v)
I = Vector{GrB_Index}(undef, nvals)
X = Vector{$type}(undef, nvals)
nvals = Ref{GrB_Index}()
nvals = Ref{GrB_Index}(nvals)
$func(I, X, nvals, v)
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
return I .+ 1, X
9 changes: 5 additions & 4 deletions src/matrix.jl
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct
to `|` for booleans and `+` for nonbooleans.
"""
function GBMatrix(
I::Vector, J::Vector, X::Vector{T};
I::AbstractVector, J::AbstractVector, X::AbstractVector{T};
dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J)
) where {T}
A = GBMatrix{T}(nrows, ncols)
@@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x.
The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for
each index.
"""
function GBMatrix(I::Vector, J::Vector, x::T;
function GBMatrix(I::AbstractVector, J::AbstractVector, x::T;
nrows = maximum(I), ncols = maximum(J)) where {T}
A = GBMatrix{T}(nrows, ncols)
build(A, I, J, x)
return A
end

function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T}
function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T}
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements"))
length(I) == length(J) || DimensionMismatch("I, J and X must have the same length")
x = GBScalar(x)
@@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
gxbprint(io, A)
end

SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
SparseArrays.nonzeros(A::GBArray) = findnz(A)[end]


# Indexing functions
####################
8 changes: 7 additions & 1 deletion src/operations/ewise.jl
Original file line number Diff line number Diff line change
@@ -61,7 +61,6 @@ function emul!(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)

size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
@@ -275,6 +274,13 @@ function eadd(
return eadd!(C, A, B, op; mask, accum, desc)
end

function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
end

function Base.:-(A::GBArray, B::GBArray)
eadd(A, B, BinaryOps.MINUS)
end
#Elementwise Broadcasts
#######################

1 change: 0 additions & 1 deletion src/operations/mul.jl
Original file line number Diff line number Diff line change
@@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
return w
end


"""
mul(A::GBArray, B::GBArray; kwargs...)::GBArray

34 changes: 34 additions & 0 deletions src/operations/transpose.jl
Original file line number Diff line number Diff line change
@@ -64,13 +64,47 @@ function Base.copy!(
return gbtranspose!(C, A.parent; mask, accum, desc)
end

"""
mask!(C::GBArray, A::GBArray, mask::GBArray)

Apply a mask to matrix `A`, storing the results in C.

"""
function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false)
desc = Descriptors.T0
structural && (desc = desc + Descriptors.S)
complement && (desc = desc + Descriptors.C)
gbtranspose!(C, A; mask, desc)
return C
end

"""
mask(A::GBArray, mask::GBArray)

Apply a mask to matrix `A`.
"""
function mask(A::GBArray, mask::GBArray; structural = false, complement = false)
return mask!(similar(A), A, mask; structural, complement)
end

function Base.copy(
A::LinearAlgebra.Transpose{<:Any, <:GBMatrix};
mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL
)
return gbtranspose(A.parent; mask, accum, desc)
end

function Base.copy(v::LinearAlgebra.Transpose{<:Any, <:GBVector})
A = GBMatrix{eltype(v)}(size(v, 1), size(v, 2))
nz = findnz(v.parent)
for i ∈ 1:length(nz[1])
println(i)
println(nz[1][i], ": ", nz[2][i])
A[1, nz[1][i]] = nz[2][i]
end
return A
end

function _handletranspose(
A::GBArray,
desc::Union{Descriptor, Nothing} = nothing,
8 changes: 4 additions & 4 deletions src/vector.jl
Original file line number Diff line number Diff line change
@@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...)
Create a GBVector from a vector of indices `I` and a vector of values `X`.
"""
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T}
x = GBVector{T}(maximum(I))
function GBVector(I::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T}
x = GBVector{T}(nrows)
build(x, I, X, dup = dup)
return x
end
@@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x.
The resulting vector is "iso-valued" such that it only stores `x` once rather than once for
each index.
"""
function GBVector(I::Vector, x::T;
function GBVector(I::AbstractVector, x::T;
nrows = maximum(I)) where {T}
A = GBVector{T}(nrows)
build(A, I, x)
return A
end

function build(A::GBVector{T}, I::Vector, x::T) where {T}
function build(A::GBVector{T}, I::AbstractVector, x::T) where {T}
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements"))
x = GBScalar(x)

17 changes: 17 additions & 0 deletions test/chainrules/chainrulesutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using FiniteDifferences
function test_to_vec(x::T; check_inferred=true) where {T}
check_inferred && @inferred FiniteDifferences.to_vec(x)
x_vec, back = FiniteDifferences.to_vec(x)
@test x_vec isa Vector
@test all(s -> s isa Real, x_vec)
check_inferred && @inferred back(x_vec)
@test x == back(x_vec)
return nothing
end

@testset "chainrulesutils" begin
y = GBMatrix(sprand(10, 10, 0.5))
test_to_vec(y)
v = GBVector(sprand(10, 0.5))
test_to_vec(v)
end
32 changes: 32 additions & 0 deletions test/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
@testset "Elementwise" begin
@testset "Dense" begin
@testset "Arithmetic Semiring" begin
#dense first
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
X = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end

@testset "Sparse" begin
@testset "Arithmetic Semiring" begin
Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector)
X = GBMatrix(sprand(10, 0.5))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end
end
21 changes: 21 additions & 0 deletions test/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
@testset "mul" begin
@testset "Dense" begin
@testset "Arithmetic Semiring" begin
M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10))
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
end
end

@testset "Sparse" begin
M = GBMatrix(sprand(100, 10, 0.25))
Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector)
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
end
end
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS
using SparseArrays
using Test
using Random

using ChainRulesTestUtils
Random.seed!(1)

function include_test(path)
@@ -12,6 +12,10 @@ end

println("Testing SuiteSparseGraphBLAS.jl")
@testset "SuiteSparseGraphBLAS" begin

include_test("gbarray.jl")
include_test("operations.jl")
include_test("chainrules/chainrulesutils.jl")
include_test("chainrules/mulrules.jl")
include_test("chainrules/mulrules.jl")
end