Skip to content

Commit 4679ad1

Browse files
committed
Implement sparse matrix copy! for unequal length.
Instead of relying on copy!(::AbstractMatrix, ::AbstractMatrix), the dest matrix is prepared so that sparse_compute_reshaped_colptr_and_rowval() can then be applied.
1 parent 15afb1e commit 4679ad1

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

base/sparse/sparsematrix.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ end
193193
copy(S::SparseMatrixCSC) =
194194
SparseMatrixCSC(S.m, S.n, copy(S.colptr), copy(S.rowval), copy(S.nzval))
195195

196-
function copy!{TvA, TiA, TvB, TiB}(A::SparseMatrixCSC{TvA,TiA},
197-
B::SparseMatrixCSC{TvB,TiB})
196+
function copy!(A::SparseMatrixCSC, B::SparseMatrixCSC)
198197
# If the two matrices have the same length then all the
199198
# elements in A will be overwritten.
200199
if length(A) == length(B)
@@ -208,10 +207,37 @@ function copy!{TvA, TiA, TvB, TiB}(A::SparseMatrixCSC{TvA,TiA},
208207
# This is like a "reshape B into A".
209208
sparse_compute_reshaped_colptr_and_rowval(A.colptr, A.rowval, A.m, A.n, B.colptr, B.rowval, B.m, B.n)
210209
end
211-
copy!(A.nzval, B.nzval)
212210
else
213-
invoke(Base.copy!, Tuple{AbstractMatrix{TvA}, AbstractMatrix{TvB}}, A, B)
211+
length(A) >= length(B) || throw(BoundsError())
212+
lB = length(B)
213+
nnzA = nnz(A)
214+
nnzB = nnz(B)
215+
# Up to which col, row, and ptr in rowval/nzval will A be overwritten?
216+
lastmodcolA = div(lB - 1, A.m) + 1
217+
lastmodrowA = mod(lB - 1, A.m) + 1
218+
lastmodptrA = A.colptr[lastmodcolA]
219+
while lastmodptrA < A.colptr[lastmodcolA+1] && A.rowval[lastmodptrA] <= lastmodrowA
220+
lastmodptrA += 1
221+
end
222+
lastmodptrA -= 1
223+
if lastmodptrA >= nnzB
224+
# A will have fewer non-zero elements; unmodified elements are kept at the end.
225+
deleteat!(A.rowval, nnzB+1:lastmodptrA)
226+
deleteat!(A.nzval, nnzB+1:lastmodptrA)
227+
else
228+
# A will have more non-zero elements; unmodified elements are kept at the end.
229+
resize!(A.rowval, nnzB + nnzA - lastmodptrA)
230+
resize!(A.nzval, nnzB + nnzA - lastmodptrA)
231+
copy!(A.rowval, nnzB+1, A.rowval, lastmodptrA+1, nnzA-lastmodptrA)
232+
copy!(A.nzval, nnzB+1, A.nzval, lastmodptrA+1, nnzA-lastmodptrA)
233+
end
234+
# Adjust colptr accordingly.
235+
@inbounds for i in 2:length(A.colptr)
236+
A.colptr[i] += nnzB - lastmodptrA
237+
end
238+
sparse_compute_reshaped_colptr_and_rowval(A.colptr, A.rowval, A.m, lastmodcolA-1, B.colptr, B.rowval, B.m, B.n)
214239
end
240+
copy!(A.nzval, B.nzval)
215241
return A
216242
end
217243

test/sparsedir/sparse.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,26 @@ let
257257
B = sprand(25, 1, 0.2)
258258
copy!(A, B)
259259
@test A[:] == B[:]
260-
# Test size(A) != size(B)
261-
B = sprand(3, 3, 0.2)
262-
copy!(A, B)
263-
@test A[1:9] == B[:]
260+
# Test various size(A) / size(B) combinations
261+
for mA in [5, 10, 20], nA in [5, 10, 20], mB in [5, 10, 20], nB in [5, 10, 20]
262+
A = sprand(mA,nA,0.4)
263+
Aorig = copy(A)
264+
B = sprand(mB,nB,0.4)
265+
if mA*nA >= mB*nB
266+
copy!(A,B)
267+
@assert(A[1:length(B)] == B[:])
268+
@assert(A[length(B)+1:end] == Aorig[length(B)+1:end])
269+
else
270+
@test_throws BoundsError copy!(A,B)
271+
end
272+
end
264273
# Test eltype(A) != eltype(B), size(A) != size(B)
274+
A = sprand(5, 5, 0.2)
275+
Aorig = copy(A)
265276
B = sparse(rand(Float32, 3, 3))
266277
copy!(A, B)
267278
@test A[1:9] == B[:]
279+
@test A[10:end] == Aorig[10:end]
268280
# Test eltype(A) != eltype(B), size(A) == size(B)
269281
A = sparse(rand(Float64, 3, 3))
270282
B = sparse(rand(Float32, 3, 3))

0 commit comments

Comments
 (0)