Skip to content

fix #37312 fast extrema computation on sparse arrays #37429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module HigherOrderFns

# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, copy, copyto!
import Base: map, map!, broadcast, copy, copyto!, _extrema_dims, _extrema_itr

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrixCSC,
Expand All @@ -29,6 +29,7 @@ using LinearAlgebra
# (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices,
# structured matrices, and one- and two-dimensional Arrays.
# (12) Define map[!] methods handling combinations of sparse and structured matrices.
# (13) Define extrema methods optimized for sparse vectors/matrices.


# (0) BroadcastStyle rules and convenience types for dispatch
Expand Down Expand Up @@ -1154,4 +1155,60 @@ map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N})
map!(f::Tf, C::AbstractSparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
(_checksameshape(C, A, Bs...); _noshapecheck_map!(f, C, _sparsifystructured(A), map(_sparsifystructured, Bs)...))


# (13) extrema methods optimized for sparse vectors/matrices.
function _extrema_itr(f, A::SparseVecOrMat)
M = length(A)
iszero(M) && throw(ArgumentError("Sparse array must have at least one element."))
N = nnz(A)
iszero(N) && return f(zero(eltype(A))), f(zero(eltype(A)))
vmin, vmax = _extrema_itr(f, nonzeros(A))
if N != M
f0 = f(zero(eltype(A)))
vmin = min(f0, vmin)
vmax = max(f0, vmax)
end
vmin, vmax
end

function _extrema_dims(f, x::SparseVector, dims)
sz = zeros(1)
for d in dims
sz[d] = 1
end
if sz == [1] && !iszero(length(x))
return [_extrema_itr(f, x)]
end
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, x, dims)
end

function _extrema_dims(f, A::AbstractSparseMatrix, dims)
sz = zeros(2)
for d in dims
sz[d] = 1
end
if sz == [1, 0] && !iszero(length(A))
T = eltype(A)
B = Array{Tuple{T,T}}(undef, 1, size(A, 2))
@inbounds for col_idx in 1:size(A, 2)
col = @view A[:,col_idx]
fx = (nnz(col) == size(A, 1)) ? f(A[1,col_idx]) : f(zero(T))
B[col_idx] = (fx, fx)
for x in nonzeros(col)
fx = f(x)
if fx < B[col_idx][1]
B[col_idx] = (fx, B[col_idx][2])
elseif fx > B[col_idx][2]
B[col_idx] = (B[col_idx][1], fx)
end
end
end
return B
end
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, A, dims)
end

_extrema_dims(f, A::SparseVector, ::Colon) = _extrema_itr(f, A)
_extrema_dims(f, A::AbstractSparseMatrix, ::Colon) = _extrema_itr(f, A)

end
33 changes: 33 additions & 0 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,4 +687,37 @@ end
@test SparseMatStyle(Val(3)) == Broadcast.DefaultArrayStyle{3}()
end

@testset "extrema" begin
n = 10
A = sprand(n, n, 0.2)
B = Array(A)
C = Array{Real}(undef, 0, 0)
x = sprand(n, 0.2)
y = Array(x)
z = Array{Real}(undef, 0)
f(x) = x^3
@test extrema(A) == extrema(B)
@test extrema(x) == extrema(y)
@test extrema(f, A) == extrema(f, B)
@test extrema(f, x) == extrema(f, y)
@test extrema(spzeros(n, n)) == (0.0, 0.0)
@test extrema(spzeros(n)) == (0.0, 0.0)
@test_throws ArgumentError extrema(spzeros(0, 0))
@test_throws ArgumentError extrema(spzeros(0))
@test extrema(sparse(ones(n, n))) == (1.0, 1.0)
@test extrema(sparse(ones(n))) == (1.0, 1.0)
@test extrema(A; dims=:) == extrema(B; dims=:)
@test extrema(A; dims=1) == extrema(B; dims=1)
@test extrema(A; dims=2) == extrema(B; dims=2)
@test extrema(A; dims=(1,2)) == extrema(B; dims=(1,2))
@test extrema(f, A; dims=1) == extrema(f, B; dims=1)
@test extrema(sparse(C); dims=1) == extrema(C; dims=1)
@test extrema(A; dims=[]) == extrema(B; dims=[])
@test extrema(x; dims=:) == extrema(y; dims=:)
@test extrema(x; dims=1) == extrema(y; dims=1)
@test extrema(f, x; dims=1) == extrema(f, y; dims=1)
@test_throws BoundsError extrema(sparse(z); dims=1)
@test extrema(x; dims=[]) == extrema(y; dims=[])
end

end # module