Skip to content

Commit c1ca0d3

Browse files
authored
Generalize Diagonal * AdjOrTransAbsMat to arbitrary element types (#52389)
1 parent 8d0eec9 commit c1ca0d3

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -320,15 +320,6 @@ end
320320
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
321321
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
322322

323-
function (*)(A::AdjOrTransAbsMat, D::Diagonal)
324-
Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))
325-
rmul!(Ac, D)
326-
end
327-
function (*)(D::Diagonal, A::AdjOrTransAbsMat)
328-
Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))
329-
lmul!(D, Ac)
330-
end
331-
332323
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
333324
require_one_based_indexing(out, B)
334325
alpha, beta = _add.alpha, _add.beta

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ using .Main.InfiniteArrays
1818
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
1919
using .Main.FillArrays
2020

21+
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
22+
using .Main.SizedArrays
23+
2124
const n=12 # Size of matrix problem to test
2225
Random.seed!(1)
2326

@@ -778,6 +781,11 @@ end
778781
D = Diagonal(fill(M, n))
779782
@test D == Matrix{eltype(D)}(D)
780783
end
784+
785+
S = SizedArray{(2,3)}(reshape([1:6;],2,3))
786+
D = Diagonal(fill(S,3))
787+
@test D * fill(S,2,3)' == fill(S * S', 3, 2)
788+
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)
781789
end
782790

783791
@testset "Eigensystem for block diagonal (issue #30681)" begin

test/testhelpers/SizedArrays.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ module SizedArrays
99

1010
import Base: +, *, ==
1111

12+
using LinearAlgebra
13+
1214
export SizedArray
1315

1416
struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
@@ -31,9 +33,16 @@ Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
3133
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
3234
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
3335
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
34-
function *(S1::SizedArray, S2::SizedArray)
36+
37+
const SizedArrayLike = Union{SizedArray, Transpose{<:Any, <:SizedArray}, Adjoint{<:Any, <:SizedArray}}
38+
39+
_data(S::SizedArray) = S.data
40+
_data(T::Transpose{<:Any, <:SizedArray}) = transpose(_data(parent(T)))
41+
_data(T::Adjoint{<:Any, <:SizedArray}) = adjoint(_data(parent(T)))
42+
43+
function *(S1::SizedArrayLike, S2::SizedArrayLike)
3544
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
36-
data = S1.data * S2.data
45+
data = _data(S1) * _data(S2)
3746
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
3847
SizedArray{SZ}(data)
3948
end

0 commit comments

Comments
 (0)