Skip to content

Commit ff316b2

Browse files
authored
Merge pull request #39207 from JuliaLang/mh/sparse-type-stability
Fix some inferability issues in SparseArrays
2 parents 03ec87f + 1a302b9 commit ff316b2

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

stdlib/SparseArrays/src/sparsematrix.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,7 +2045,7 @@ function _findr(op, A, region, Tv)
20452045
throw(ArgumentError("array slices must be non-empty"))
20462046
else
20472047
ri = Base.reduced_indices0(A, region)
2048-
return (similar(A, ri), zeros(Ti, ri))
2048+
return (zeros(Tv, ri), zeros(Ti, ri))
20492049
end
20502050
end
20512051

@@ -3274,6 +3274,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :)
32743274

32753275
# Sparse concatenation
32763276

3277+
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti
3278+
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} =
3279+
promote_type(Ti, promote_idxtype(X...))
3280+
32773281
function vcat(X::AbstractSparseMatrixCSC...)
32783282
num = length(X)
32793283
mX = Int[ size(x, 1) for x in X ]
@@ -3288,7 +3292,7 @@ function vcat(X::AbstractSparseMatrixCSC...)
32883292
end
32893293

32903294
Tv = promote_eltype(X...)
3291-
Ti = promote_eltype(map(x->rowvals(x), X)...)
3295+
Ti = promote_idxtype(X...)
32923296

32933297
nnzX = Int[ nnz(x) for x in X ]
32943298
nnz_res = sum(nnzX)
@@ -3340,7 +3344,7 @@ function hcat(X::AbstractSparseMatrixCSC...)
33403344
n = sum(nX)
33413345

33423346
Tv = promote_eltype(X...)
3343-
Ti = promote_eltype(map(x->rowvals(x), X)...)
3347+
Ti = promote_idxtype(X...)
33443348

33453349
colptr = Vector{Ti}(undef, n+1)
33463350
nnzX = Int[ nnz(x) for x in X ]

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,17 +1091,16 @@ function vcat(Xin::_SparseConcatGroup...)
10911091
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
10921092
vcat(X...)
10931093
end
1094-
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1095-
nbr = length(rows) # number of block rows
1096-
1097-
tmp_rows = Vector{SparseMatrixCSC}(undef, nbr)
1098-
k = 0
1099-
@inbounds for i = 1 : nbr
1100-
tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...)
1101-
k += rows[i]
1094+
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
1095+
vcat(_hvcat_rows(rows, X...)...)
1096+
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1097+
if row1 0
1098+
throw(ArgumentError("length of block row must be positive, got $row1"))
11021099
end
1103-
vcat(tmp_rows...)
1100+
# provide X[1] separately to convince inference that we don't call hcat() without arguments
1101+
return (hcat(X[1], X[2 : row1]...), _hvcat_rows(rows, X[row1+1:end]...)...)
11041102
end
1103+
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()
11051104

11061105
# make sure UniformScaling objects are converted to sparse matrices for concatenation
11071106
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC

stdlib/SparseArrays/test/sparse.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ end
165165
sz34 = spzeros(3, 4)
166166
se77 = sparse(1.0I, 7, 7)
167167
@testset "h+v concatenation" begin
168-
@test [se44 sz42 sz41; sz34 se33] == se77
168+
@test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33]
169169
@test length(nonzeros([sp33 0I; 1I 0I])) == 6
170170
end
171171

@@ -1338,10 +1338,10 @@ end
13381338
@testset "argmax, argmin, findmax, findmin" begin
13391339
S = sprand(100,80, 0.5)
13401340
A = Array(S)
1341-
@test argmax(S) == argmax(A)
1342-
@test argmin(S) == argmin(A)
1343-
@test findmin(S) == findmin(A)
1344-
@test findmax(S) == findmax(A)
1341+
@test @inferred(argmax(S)) == argmax(A)
1342+
@test @inferred(argmin(S)) == argmin(A)
1343+
@test @inferred(findmin(S)) == findmin(A)
1344+
@test @inferred(findmax(S)) == findmax(A)
13451345
for region in [(1,), (2,), (1,2)], m in [findmax, findmin]
13461346
@test m(S, dims=region) == m(A, dims=region)
13471347
end
@@ -2201,7 +2201,7 @@ end
22012201
# Test that concatenations of pairs of sparse matrices yield sparse arrays
22022202
@test issparse(vcat(spmat, spmat))
22032203
@test issparse(hcat(spmat, spmat))
2204-
@test issparse(hvcat((2,), spmat, spmat))
2204+
@test issparse(@inferred(hvcat((2,), spmat, spmat)))
22052205
@test issparse(cat(spmat, spmat; dims=(1,2)))
22062206
# Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays
22072207
@test issparse(vcat(spmat, densemat))

0 commit comments

Comments
 (0)