Skip to content

Commit 63bd1fa

Browse files
author
Clément Fauchereau
committed
fast extrema computation on sparse arrays
1 parent bf886b5 commit 63bd1fa

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

stdlib/SparseArrays/src/higherorderfns.jl

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module HigherOrderFns
44

55
# This module provides higher order functions specialized for sparse arrays,
66
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7-
import Base: map, map!, broadcast, copy, copyto!
7+
import Base: map, map!, broadcast, copy, copyto!, _extrema_dims, _extrema_itr
88

99
using Base: front, tail, to_shape
1010
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrixCSC,
@@ -29,6 +29,7 @@ using LinearAlgebra
2929
# (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices,
3030
# structured matrices, and one- and two-dimensional Arrays.
3131
# (12) Define map[!] methods handling combinations of sparse and structured matrices.
32+
# (13) Define extrema methods optimized for sparse vectors/matrices.
3233

3334

3435
# (0) BroadcastStyle rules and convenience types for dispatch
@@ -1154,4 +1155,58 @@ map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N})
11541155
map!(f::Tf, C::AbstractSparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
11551156
(_checksameshape(C, A, Bs...); _noshapecheck_map!(f, C, _sparsifystructured(A), map(_sparsifystructured, Bs)...))
11561157

1158+
1159+
# (13) extrema methods optimized for sparse vectors/matrices.
1160+
function _extrema_itr(f, A::SparseVecOrMat)
1161+
M = length(A)
1162+
iszero(M) && throw(ArgumentError("Sparse array must have at least one element."))
1163+
N = nnz(A)
1164+
iszero(N) && return f(zero(eltype(A))), f(zero(eltype(A)))
1165+
vmin, vmax = _extrema_itr(f, nonzeros(A))
1166+
if N != M
1167+
f0 = f(zero(eltype(A)))
1168+
vmin = min(f0, vmin)
1169+
vmax = max(f0, vmax)
1170+
end
1171+
vmin, vmax
1172+
end
1173+
1174+
function _extrema_dims(f, x::SparseVector{Tv, Ti}, dims) where {Tv, Ti}
1175+
sz = zeros(1)
1176+
for d in dims
1177+
sz[d] = 1
1178+
end
1179+
if sz == [1] && !iszero(length(x))
1180+
return [_extrema_itr(f, x)]
1181+
end
1182+
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, x, dims)
1183+
end
1184+
1185+
function _extrema_dims(f, A::AbstractSparseMatrix{Tv, Ti}, dims) where {Tv, Ti}
1186+
sz = zeros(2)
1187+
for d in dims
1188+
sz[d] = 1
1189+
end
1190+
if sz == [1, 0] && !iszero(length(A))
1191+
B = Array{Tuple{Tv,Tv}}(undef, 1, size(A, 2))
1192+
@inbounds for col_idx in 1:size(A, 2)
1193+
col = @view A[:,col_idx]
1194+
fx = (nnz(col) == size(A, 1)) ? f(A[1,col_idx]) : f(zero(Tv))
1195+
B[col_idx] = (fx, fx)
1196+
for x in nonzeros(col)
1197+
fx = f(x)
1198+
if fx < B[col_idx][1]
1199+
B[col_idx] = (fx, B[col_idx][2])
1200+
elseif fx > B[col_idx][2]
1201+
B[col_idx] = (B[col_idx][1], fx)
1202+
end
1203+
end
1204+
end
1205+
return B
1206+
end
1207+
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, A, dims)
1208+
end
1209+
1210+
_extrema_dims(f, A::SparseVecOrMat, ::Colon) = _extrema_itr(f, A)
1211+
11571212
end

stdlib/SparseArrays/test/higherorderfns.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,4 +687,37 @@ end
687687
@test SparseMatStyle(Val(3)) == Broadcast.DefaultArrayStyle{3}()
688688
end
689689

690+
@testset "extrema" begin
691+
n = 10
692+
A = sprand(n, n, 0.2)
693+
B = Array(A)
694+
C = Array{Real}(undef, 0, 0)
695+
x = sprand(n, 0.2)
696+
y = Array(x)
697+
z = Array{Real}(undef, 0)
698+
f(x) = x^3
699+
@test extrema(A) == extrema(B)
700+
@test extrema(x) == extrema(y)
701+
@test extrema(f, A) == extrema(f, B)
702+
@test extrema(f, x) == extrema(f, y)
703+
@test extrema(spzeros(n, n)) == (0.0, 0.0)
704+
@test extrema(spzeros(n)) == (0.0, 0.0)
705+
@test_throws ArgumentError extrema(spzeros(0, 0))
706+
@test_throws ArgumentError extrema(spzeros(0))
707+
@test extrema(sparse(ones(n, n))) == (1.0, 1.0)
708+
@test extrema(sparse(ones(n))) == (1.0, 1.0)
709+
@test extrema(A; dims=:) == extrema(B; dims=:)
710+
@test extrema(A; dims=1) == extrema(B; dims=1)
711+
@test extrema(A; dims=2) == extrema(B; dims=2)
712+
@test extrema(A; dims=(1,2)) == extrema(B; dims=(1,2))
713+
@test extrema(f, A; dims=1) == extrema(f, B; dims=1)
714+
@test extrema(sparse(C); dims=1) == extrema(C; dims=1)
715+
@test extrema(A; dims=[]) == extrema(B; dims=[])
716+
@test extrema(x; dims=:) == extrema(y; dims=:)
717+
@test extrema(x; dims=1) == extrema(y; dims=1)
718+
@test extrema(f, x; dims=1) == extrema(f, y; dims=1)
719+
@test_throws BoundsError extrema(sparse(z); dims=1)
720+
@test extrema(x; dims=[]) == extrema(y; dims=[])
721+
end
722+
690723
end # module

0 commit comments

Comments
 (0)