Skip to content

Commit 6f5c862

Browse files
committed
axes(::SDiagonal) is statically sized
vector indexing for SDiagonal produces SArrays
1 parent 29a76ec commit 6f5c862

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

src/SDiagonal.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} = Diagonal(diag(a))
1717
size(::Type{<:SDiagonal{N}}) where {N} = (N,N)
1818
size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N
1919

20+
Base.axes(D::SDiagonal) = (ax = axes(diag(D), 1); (ax, ax))
21+
Base.axes(D::SDiagonal, d) = d <= 2 ? axes(D)[d] : SOneTo(1)
22+
23+
Base.reshape(a::SDiagonal, s::Tuple{SOneTo,Vararg{SOneTo}}) = reshape(a, homogenize_shape(s))
24+
2025
# define specific methods to avoid allocating mutable arrays
2126
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
2227
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity
@@ -56,3 +61,5 @@ function inv(D::SDiagonal)
5661
check_singular(D)
5762
SDiagonal(inv.(D.diag))
5863
end
64+
65+
Base.copy(D::SDiagonal) = Diagonal(copy(diag(D)))

src/indexing.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,29 @@ Base.unsafe_view(A::AbstractArray, i1::StaticIndexing, indices::StaticIndexing..
377377
# the tuple indices has to have at least one element to prevent infinite
378378
# recursion when viewing a zero-dimensional array (see issue #705)
379379
Base.SubArray(A::AbstractArray, indices::Tuple{StaticIndexing, Vararg{StaticIndexing}}) = Base.SubArray(A, map(unwrap, indices))
380+
381+
###########################################################
382+
# SDiagonal
383+
###########################################################
384+
385+
# SDiagonal uses Cartesian indexing, and the canonical indexing methods shadow getindex for Diagonal
386+
# these are needed for ambiguity resolution
387+
@inline function getindex(D::SDiagonal, i::Int, j::Int)
388+
@boundscheck checkbounds(D, i, j)
389+
if i == j
390+
@inbounds r = diag(D)[i]
391+
else
392+
r = LinearAlgebra.diagzero(D, i, j)
393+
end
394+
r
395+
end
396+
@inline function getindex(D::SDiagonal, i::Int...)
397+
@boundscheck checkbounds(D, i...)
398+
@inbounds r = D[eachindex(D)[i...]]
399+
r
400+
end
401+
# Ensure that vector indexing with static types lead to SArrays
402+
@propagate_inbounds function getindex(a::SDiagonal, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
403+
ar = reshape(a, Val(length(inds)))
404+
_getindex(ar, index_sizes(Size(ar), inds...), inds)
405+
end

test/SDiagonal.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ using StaticArrays, Test, LinearAlgebra
7070

7171
@test length(m) === 4*4
7272

73+
m2 = SMatrix{4,4}(m)
74+
@test axes(m) === axes(m2)
75+
@test axes(m, 1) === axes(m2, 1)
76+
@test axes(m, 3) == SOneTo(1)
77+
78+
@test m[:, 1] === SVector{4}(m[1,1], 0, 0, 0)
79+
@test m[:, :] === m2
80+
@test m[2, 2, 1] === m[2, 2]
81+
7382
@test_throws Exception m[1] = 1
7483

7584
b = @SVector [2,-1,2,1]
@@ -114,5 +123,7 @@ using StaticArrays, Test, LinearAlgebra
114123

115124
@test m + zero(m) == m
116125
@test m + zero(typeof(m)) == m
126+
127+
@test copy(m) === m
117128
end
118129
end

0 commit comments

Comments
 (0)