Skip to content

Commit d0f24c7

Browse files
authored
Tidy one() implementation and UniformScaling constructors (#690)
* Remove all reference to the deprecated name `eye` (replace internal _eye with _scaler_matrix) * Remove all the more-or-less duplicate definitions of one() and use _scalar_matrix as the implementation. * Fix StaticMatrixLike definition to properly include Hermitian and Symmetric wrappers of static matrices. * Import LinearAlgebra.checksquare() and use it directly. * Remove obsolete checksquare specializations & add tests * Introduce _construct_similar to alleviate the difficulty of computing eltype purely in the type domain with non-concrete UnionAlls.
1 parent a3bca35 commit d0f24c7

14 files changed

+89
-73
lines changed

src/MArray.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,6 @@ end
7373

7474
@inline MArray(a::StaticArray) = MArray{size_tuple(Size(a))}(Tuple(a))
7575

76-
# Simplified show for the type
77-
#show(io::IO, ::Type{MArray{S, T, N}}) where {S, T, N} = print(io, "MArray{$S,$T,$N}")
78-
79-
# Some more advanced constructor-like functions
80-
@inline one(::Type{MArray{S}}) where {S} = one(MArray{S,Float64,tuple_length(S)})
81-
@inline one(::Type{MArray{S,T}}) where {S,T} = one(MArray{S,T,tuple_length(S)})
82-
83-
# MArray(I::UniformScaling) methods to replace eye
84-
(::Type{MA})(I::UniformScaling) where {MA<:MArray} = _eye(Size(MA), MA, I)
85-
8676
####################
8777
## MArray methods ##
8878
####################

src/MVector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ macro MVector(ex)
9595
error("@MVector expected a 1-dimensional array expression")
9696
end
9797
else
98-
error("@MVector only supports the zeros(), ones(), rand(), randn(), randexp(), and eye() functions.")
98+
error("@MVector only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
9999
end
100100
else
101101
error("Use @MVector [a,b,c] or @MVector([a,b,c])")

src/SArray.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,6 @@ end
5252

5353
@inline SArray(a::StaticArray) = SArray{size_tuple(Size(a))}(Tuple(a))
5454

55-
# Simplified show for the type
56-
# show(io::IO, ::Type{SArray{S, T, N}}) where {S, T, N} = print(io, "SArray{$S,$T,$N}") # TODO reinstate
57-
58-
# Some more advanced constructor-like functions
59-
@inline one(::Type{SArray{S}}) where {S} = one(SArray{S, Float64, tuple_length(S)})
60-
@inline one(::Type{SArray{S, T}}) where {S, T} = one(SArray{S, T, tuple_length(S)})
61-
62-
# SArray(I::UniformScaling) methods to replace eye
63-
(::Type{SA})(I::UniformScaling) where {SA<:SArray} = _eye(Size(SA), SA, I)
64-
6555
####################
6656
## SArray methods ##
6757
####################

src/SDiagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N
3333
# override to avoid copying
3434
diag(D::SDiagonal) = D.diag
3535

36-
# SDiagonal(I::UniformScaling) methods to replace eye
36+
# SDiagonal(I::UniformScaling) methods
3737
(::Type{SDiagonal{N}})(I::UniformScaling) where {N} = SDiagonal{N}(ntuple(x->I.λ, Val(N)))
3838
(::Type{SDiagonal{N,T}})(I::UniformScaling) where {N,T} = SDiagonal{N,T}(ntuple(x->I.λ, Val(N)))
3939

src/SHermitianCompact.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ end
230230
end
231231
end
232232

233-
@inline _eye(s::Size{S}, t::Type{SSC}) where {S, SSC <: SHermitianCompact} = _one(s, t)
233+
@inline _scalar_matrix(s::Size{S}, t::Type{SSC}) where {S, SSC <: SHermitianCompact} = _one(s, t)
234234

235235
# _fill covers fill, zeros, and ones:
236236
@generated function _fill(val, ::Size{s}, ::Type{SSC}) where {s, SSC <: SHermitianCompact}

src/SVector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ macro SVector(ex)
108108
error("@SVector expected a 1-dimensional array expression")
109109
end
110110
else
111-
error("@SVector only supports the zeros(), ones(), rand(), randn(), randexp(), and eye() functions.")
111+
error("@SVector only supports the zeros(), ones(), rand(), randn() and randexp() functions.")
112112
end
113-
else # TODO Expr(:call, :zeros), Expr(:call, :ones), Expr(:call, :eye) ?
113+
else
114114
error("Use @SVector [a,b,c], @SVector Type[a,b,c] or a comprehension like [f(i) for i = i_min:i_max]")
115115
end
116116
end

src/StaticArrays.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@ import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr,
1919
kron, diag, norm, dot, diagm, lu, svd, svdvals,
2020
factorize, ishermitian, issymmetric, isposdef, normalize,
2121
normalize!, Eigen, det, logdet, cross, diff, qr, \
22-
23-
# import eye for deprecation warnings
24-
@static if isdefined(LinearAlgebra, :eye)
25-
import LinearAlgebra: eye
26-
end
22+
using LinearAlgebra: checksquare
2723

2824
export SOneTo
2925
export StaticScalar, StaticArray, StaticVector, StaticMatrix
@@ -88,8 +84,8 @@ const StaticMatrixLike{T} = Union{
8884
StaticMatrix{<:Any, <:Any, T},
8985
Transpose{T, <:StaticVecOrMat{T}},
9086
Adjoint{T, <:StaticVecOrMat{T}},
91-
Symmetric{T, <:StaticMatrix{T}},
92-
Hermitian{T, <:StaticMatrix{T}},
87+
Symmetric{T, <:StaticMatrix{<:Any, <:Any, T}},
88+
Hermitian{T, <:StaticMatrix{<:Any, <:Any, T}},
9389
Diagonal{T, <:StaticVector{<:Any, T}}
9490
}
9591
const StaticVecOrMatLike{T} = Union{StaticVector{<:Any, T}, StaticMatrixLike{T}}

src/abstractarray.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,27 @@ mutable_similar_type(::Type{T}, s::Size{S}, ::Type{Val{D}}) where {T,S,D} = MArr
9393

9494
sizedarray_similar_type(::Type{T},s::Size{S},::Type{Val{D}}) where {T,S,D} = SizedArray{Tuple{S...},T,D,length(s)}
9595

96+
# Utility for computing the eltype of an array instance, type, or type
97+
# constructor. For type constructors without a definite eltype, the default
98+
# value is returned.
99+
Base.@pure _eltype_or(a::AbstractArray, default) = eltype(a)
100+
Base.@pure _eltype_or(::Type{<:AbstractArray{T}}, default) where {T} = T
101+
Base.@pure _eltype_or(::Type{<:AbstractArray}, default) = default # eltype not available
102+
103+
"""
104+
_construct_similar(a, ::Size, elements::NTuple)
105+
106+
Construct a static array of similar type to `a` with the given `elements`.
107+
108+
When `a` is an instance or a concrete type the element type `eltype(a)` is
109+
used. However, when `a` is a `UnionAll` type such as `SMatrix{2,2}`, the
110+
promoted type of `elements` is used instead.
111+
"""
112+
@inline function _construct_similar(a, s::Size, elements::NTuple{L,ET}) where {L,ET}
113+
similar_type(a, _eltype_or(a, ET), s)(elements)
114+
end
115+
116+
96117
# Field vectors are user controlled, and currently default to SVector, etc
97118

98119
"""

src/det.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
end
4444

4545
@generated function _det(::Size{S}, A::StaticMatrix) where S
46-
LinearAlgebra.checksquare(A)
46+
checksquare(A)
4747
if prod(S) 14*14
4848
quote
4949
@_inline_meta
@@ -58,7 +58,7 @@ end
5858
@inline logdet(A::StaticMatrix) = _logdet(Size(A), A)
5959
@inline _logdet(::Union{Size{(1,1)}, Size{(2,2)}, Size{(3,3)}, Size{(4,4)}}, A::StaticMatrix) = log(det(A))
6060
@generated function _logdet(::Size{S}, A::StaticMatrix) where S
61-
LinearAlgebra.checksquare(A)
61+
checksquare(A)
6262
if prod(S) 14*14
6363
quote
6464
@_inline_meta

src/linalg.jl

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import Base: +, -, *, /, \
22

3-
# TODO: more operators, like AbstractArray
3+
#--------------------------------------------------
4+
# Vector space algebra
45

56
# Unary ops
67
@inline -(a::StaticArray) = map(-, a)
@@ -30,10 +31,7 @@ import Base: +, -, *, /, \
3031
@inline -(a::UniformScaling, b::StaticMatrix) = _plus_uniform(Size(b), -b, a.λ)
3132

3233
@generated function _plus_uniform(::Size{S}, a::StaticMatrix, λ) where {S}
33-
if S[1] != S[2]
34-
throw(DimensionMismatch("matrix is not square: dimensions are $S"))
35-
end
36-
n = S[1]
34+
n = checksquare(a)
3735
exprs = [i == j ? :(a[$(LinearIndices(S)[i, j])] + λ) : :(a[$(LinearIndices(S)[i, j])]) for i = 1:n, j = 1:n]
3836
return quote
3937
$(Expr(:meta, :inline))
@@ -46,6 +44,8 @@ end
4644
@inline \(a::UniformScaling, b::Union{StaticMatrix,StaticVector}) = a.λ \ b
4745
@inline /(a::StaticMatrix, b::UniformScaling) = a / b.λ
4846

47+
#--------------------------------------------------
48+
# Matrix algebra
4949

5050
# Transpose, conjugate, etc
5151
@inline conj(a::StaticArray) = map(conj, a)
@@ -85,31 +85,30 @@ end
8585
@inline Base.zero(a::SA) where {SA <: StaticArray} = zeros(SA)
8686
@inline Base.zero(a::Type{SA}) where {SA <: StaticArray} = zeros(SA)
8787

88-
@inline one(::SM) where {SM <: StaticMatrix} = _one(Size(SM), SM)
89-
@inline one(::Type{SM}) where {SM <: StaticMatrix} = _one(Size(SM), SM)
90-
@generated function _one(::Size{S}, ::Type{SM}) where {S, SM <: StaticArray}
91-
if (length(S) != 2) || (S[1] != S[2])
92-
error("multiplicative identity defined only for square matrices")
93-
end
94-
T = eltype(SM) # should be "hyperpure"
95-
if T == Any
96-
T = Float64
97-
end
98-
exprs = [i == j ? :(one($T)) : :(zero($T)) for i 1:S[1], j 1:S[2]]
99-
return quote
100-
$(Expr(:meta, :inline))
101-
SM(tuple($(exprs...)))
88+
@inline one(m::StaticMatrixLike) = _one(Size(m), m)
89+
@inline one(::Type{SM}) where {SM<:StaticMatrixLike}= _one(Size(SM), SM)
90+
function _one(s::Size, m_or_SM)
91+
if (length(s) != 2) || (s[1] != s[2])
92+
throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
10293
end
94+
_scalar_matrix(s, m_or_SM, one(_eltype_or(m_or_SM, Float64)))
10395
end
10496

105-
# StaticMatrix(I::UniformScaling) methods to replace eye
106-
(::Type{SM})(I::UniformScaling) where {N,M,SM<:StaticMatrix{N,M}} = _eye(Size(SM), SM, I)
107-
108-
@generated function _eye(::Size{S}, ::Type{SM}, I::UniformScaling{T}) where {S, SM <: StaticArray, T}
109-
exprs = [i == j ? :(I.λ) : :(zero($T)) for i 1:S[1], j 1:S[2]]
97+
# StaticMatrix(I::UniformScaling)
98+
(::Type{SM})(I::UniformScaling) where {SM<:StaticMatrix} = _scalar_matrix(Size(SM), SM, I.λ)
99+
# The following oddity is needed if we want `SArray{Tuple{2,3}}(I)` to work
100+
# because we do not have `SArray{Tuple{2,3}} <: StaticMatrix`.
101+
(::Type{SM})(I::UniformScaling) where {SM<:(StaticArray{Tuple{N,M}} where {N,M})} =
102+
_scalar_matrix(Size(SM), SM, I.λ)
103+
104+
# Construct a matrix with the scalar λ on the diagonal and zeros off the
105+
# diagonal. The matrix can be non-square.
106+
@generated function _scalar_matrix(s::Size{S}, m_or_SM, λ) where {S}
107+
elements = Symbol[i == j ? : :λzero for i in 1:S[1], j in 1:S[2]]
110108
return quote
111109
$(Expr(:meta, :inline))
112-
SM(tuple($(exprs...)))
110+
λzero = zero(λ)
111+
_construct_similar(m_or_SM, s, tuple($(elements...)))
113112
end
114113
end
115114

@@ -145,6 +144,8 @@ end
145144
end
146145
end
147146

147+
#--------------------------------------------------
148+
# Vector products
148149
@inline cross(a::StaticVector, b::StaticVector) = _cross(same_size(a, b), a, b)
149150
_cross(::Size{S}, a::StaticVector, b::StaticVector) where {S} = error("Cross product not defined for $(S[1])-vectors")
150151
@inline function _cross(::Size{(2,)}, a::StaticVector, b::StaticVector)
@@ -179,6 +180,8 @@ end
179180
return ret
180181
end
181182

183+
#--------------------------------------------------
184+
# Norms
182185
@inline LinearAlgebra.norm_sqr(v::StaticVector) = mapreduce(abs2, +, v; init=zero(real(eltype(v))))
183186

184187
@inline norm(a::StaticArray) = _norm(Size(a), a)
@@ -240,9 +243,7 @@ end
240243

241244
@inline tr(a::StaticMatrix) = _tr(Size(a), a)
242245
@generated function _tr(::Size{S}, a::StaticMatrix) where {S}
243-
if S[1] != S[2]
244-
throw(DimensionMismatch("matrix is not square"))
245-
end
246+
checksquare(a)
246247

247248
if S[1] == 0
248249
return :(zero(eltype(a)))
@@ -257,6 +258,10 @@ end
257258
end
258259
end
259260

261+
262+
#--------------------------------------------------
263+
# Outer products
264+
260265
const _length_limit = Length(200)
261266

262267
@inline kron(a::StaticMatrix, b::StaticMatrix) = _kron(_length_limit, Size(a), Size(b), a, b)
@@ -414,11 +419,9 @@ end
414419
end
415420
end
416421

417-
# some micro-optimizations (TODO check these make sense for v0.6+)
418-
@inline LinearAlgebra.checksquare(::SM) where {SM<:StaticMatrix} = _checksquare(Size(SM))
419-
@inline LinearAlgebra.checksquare(::Type{SM}) where {SM<:StaticMatrix} = _checksquare(Size(SM))
420422

421-
@pure _checksquare(::Size{S}) where {S} = (S[1] == S[2] || throw(DimensionMismatch("matrix is not square: dimensions are $S")); S[1])
423+
#--------------------------------------------------
424+
# Some shimming for special linear algebra matrix types
425+
@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo))
426+
@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo))
422427

423-
@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo))
424-
@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo))

src/qr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ end
9898
Q = [Symbol("Q_$(i)_$(j)") for i = 1:m, j = 1:m]
9999
R = [Symbol("R_$(i)_$(j)") for i = 1:m, j = 1:n]
100100

101-
initQ = [:($(Q[i, j]) = $(i == j ? one : zero)(T)) for i = 1:m, j = 1:m] # Q .= eye(A)
101+
initQ = [:($(Q[i, j]) = $(i == j ? one : zero)(T)) for i = 1:m, j = 1:m] # Q .= I
102102
initR = [:($(R[i, j]) = T(A[$i, $j])) for i = 1:m, j = 1:n] # R .= A
103103

104104
code = quote end

src/util.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,4 @@ Base.@propagate_inbounds function invperm(p::StaticVector)
108108
ip[p] = 1:length(p)
109109
similar_type(p)(ip)
110110
end
111+

test/linalg.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using StaticArrays, Test, LinearAlgebra
22

3+
using LinearAlgebra: checksquare
4+
35
@testset "Linear algebra" begin
46

57
@testset "SArray as a (mathematical) vector space" begin
@@ -100,7 +102,7 @@ using StaticArrays, Test, LinearAlgebra
100102
@test @inferred(one(MMatrix{2,2}))::MMatrix == @MMatrix [1.0 0.0; 0.0 1.0]
101103
@test @inferred(one(MMatrix{2}))::MMatrix == @MMatrix [1.0 0.0; 0.0 1.0]
102104

103-
@test_throws ErrorException one(MMatrix{2,4})
105+
@test_throws DimensionMismatch one(MMatrix{2,4})
104106
end
105107

106108
@testset "cross()" begin
@@ -265,4 +267,12 @@ using StaticArrays, Test, LinearAlgebra
265267
@test @inferred(kron(A,transpose(b)))::SizedMatrix{10,210} == kron(P,transpose(q))
266268

267269
end
270+
271+
@testset "checksquare" begin
272+
m22 = SA[1 2; 3 4]
273+
@test @inferred(checksquare(m22)) === 2
274+
@test_inlined checksquare(m22)
275+
m23 = SA[1 2 3; 4 5 6]
276+
@test_inlined checksquare(m23) false
277+
end
268278
end

test/testutil.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Note that LLVM IR can contain `call` instructions to intrinsics which don't
5858
make it into the native code, so this can be overly eager in declaring a
5959
a lack of complete inlining.
6060
"""
61-
macro test_inlined(ex)
61+
macro test_inlined(ex, should_inline=true)
6262
ex_orig = ex
6363
ex = macroexpand(@__MODULE__, :(@code_llvm $ex))
6464
expr = quote
@@ -71,7 +71,10 @@ macro test_inlined(ex)
7171
# TODO: Figure out some better pattern matching; LLVM IR can contain
7272
# calls to intrinsics, so this will sometimes/often fail even when the
7373
# native code has no call instructions.
74-
@test !occursin("call", code_str)
74+
$(should_inline ?
75+
:(@test !occursin("call", code_str)) :
76+
:(@test occursin("call", code_str))
77+
)
7578
end
7679
@assert expr.args[4].head == :macrocall
7780
expr.args[4].args[2] = __source__
@@ -84,9 +87,11 @@ should_not_be_inlined(x) = _should_not_be_inlined(x)
8487

8588
@testset "@test_inlined" begin
8689
@test_inlined should_be_inlined(1)
90+
@test_inlined should_not_be_inlined(1) false
8791
ts = @testset ErrorCounterTestSet "" begin
92+
@test_inlined should_be_inlined(1) false
8893
@test_inlined should_not_be_inlined(1)
8994
end
90-
@test ts.errorcount == 0 && ts.failcount == 1 && ts.passcount == 0
95+
@test ts.errorcount == 0 && ts.failcount == 2 && ts.passcount == 0
9196
end
9297

0 commit comments

Comments
 (0)