Skip to content

Commit 00cac74

Browse files
authored
Generalize sparse matmul (#62)
1 parent 6738bee commit 00cac74

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.10"
4+
version = "0.5.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
8+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
910
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1011
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
@@ -15,6 +16,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1516

1617
[compat]
1718
Accessors = "0.1.41"
19+
Adapt = "4.3.0"
1820
Aqua = "0.8.9"
1921
ArrayLayouts = "1.11.0"
2022
DerivableInterfaces = "0.5"

src/abstractsparsearray.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@ function Base._cat(dims, a::AnyAbstractSparseArray...)
3535
return concatenate(dims, a...)
3636
end
3737

38+
function map_stored(f, a::AnyAbstractSparseArray)
39+
kvs = storedpairs(a)
40+
# `collect` to convert to `Vector`, since otherwise
41+
# if it stays as `Dictionary` we might hit issues like
42+
# https://github.com/andyferris/Dictionaries.jl/issues/163.
43+
ks = collect(first.(kvs))
44+
vs = collect(last.(kvs))
45+
vs′ = map(f, vs)
46+
a′ = zero!(similar(a, eltype(vs′)))
47+
for (k, v′) in zip(ks, vs′)
48+
a′[k] = v′
49+
end
50+
return a′
51+
end
52+
53+
using Adapt: adapt
54+
function Base.print_array(io::IO, a::AnyAbstractSparseArray)
55+
a′ = map_stored(adapt(Array), a)
56+
return @invoke Base.print_array(io::typeof(io), a′::AbstractArray{<:Any,ndims(a)})
57+
end
3858
function Base.replace_in_print_matrix(
3959
a::AnyAbstractSparseVecOrMat, i::Integer, j::Integer, s::AbstractString
4060
)

src/abstractsparsearrayinterface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,11 @@ function sparse_mul!(
256256
for I2 in eachstoredindex(a2)
257257
I_dest = mul_indices(I1, I2)
258258
if !isnothing(I_dest)
259-
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β′)
259+
if isstored(a_dest, I_dest)
260+
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β′)
261+
else
262+
a_dest[I_dest] = a1[I1] * a2[I2] * α
263+
end
260264
end
261265
end
262266
end

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
44
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
5-
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
65
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
76
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
87
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,7 +16,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1716
Adapt = "4.2.0"
1817
Aqua = "0.8.11"
1918
ArrayLayouts = "1.11.1"
20-
DerivableInterfaces = "0.5"
2119
Dictionaries = "0.4.4"
2220
JLArrays = "0.2.0"
2321
LinearAlgebra = "<0.0.1, 1"

0 commit comments

Comments
 (0)