Skip to content

BroadcastStyle for lazy triangular or HermOrSym, improved broadcast * #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 25, 2025
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "2.3.2"
version = "2.4"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
10 changes: 8 additions & 2 deletions ext/LazyArraysBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
AbstractPaddedLayout, PaddedLayout, AbstractLazyBandedLayout, LazyBandedLayout, PaddedRows,
PaddedColumns, CachedArray, CachedMatrix, LazyLayout, BroadcastLayout, ApplyLayout,
paddeddata, resizedata!, broadcastlayout, _broadcastarray2broadcasted, _broadcast_sub_arguments,
arguments, call, applybroadcaststyle, simplify, simplifiable, islazy_layout, lazymaterialize, _broadcast_mul_mul,
arguments, call, applybroadcaststyle, simplify, simplifiable, islazy_layout, lazymaterialize, _broadcast_mul_mul, _broadcast_mul_simplifiable,
triangularlayout, AbstractCachedMatrix, _mulbanded_copyto!, ApplyBandedLayout, BroadcastBandedLayout
import Base: BroadcastStyle, similar, copy, broadcasted, getindex, OneTo, oneto, tail, sign, abs
import BandedMatrices: bandedbroadcaststyle, bandwidths, isbanded, bandedcolumns, bandeddata, BandedStyle,
Expand Down Expand Up @@ -539,6 +539,7 @@
const BandedLazyLayouts = Union{AbstractLazyBandedLayout, BandedColumns{LazyLayout}, BandedRows{LazyLayout},
TriangularLayout{UPLO,UNIT,BandedRows{LazyLayout}} where {UPLO,UNIT},
TriangularLayout{UPLO,UNIT,BandedColumns{LazyLayout}} where {UPLO,UNIT},
TriangularLayout{UPLO,UNIT,LazyBandedLayout} where {UPLO,UNIT},
SymTridiagonalLayout{LazyLayout}, BidiagonalLayout{LazyLayout}, TridiagonalLayout{LazyLayout},
SymmetricLayout{BandedColumns{LazyLayout}}, HermitianLayout{BandedColumns{LazyLayout}}}

Expand All @@ -551,7 +552,12 @@
copy(M::Mul{<:Any, <:BandedLazyLayouts}) = simplify(M)
copy(M::Mul{<:BandedLazyLayouts, <:AbstractLazyLayout}) = simplify(M)
copy(M::Mul{<:AbstractLazyLayout, <:BandedLazyLayouts}) = simplify(M)
copy(M::Mul{BroadcastLayout{typeof(*)}, <:BandedLazyLayouts}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
for op in (:*, :/, :\)
@eval begin
simplifiable(M::Mul{BroadcastLayout{typeof($op)}, <:BandedLazyLayouts}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
copy(M::Mul{BroadcastLayout{typeof($op)}, <:BandedLazyLayouts}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)

Check warning on line 558 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L557-L558

Added lines #L557 - L558 were not covered by tests
end
end
copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout}) = simplify(M)
copy(M::Mul{<:DiagonalLayout, <:BandedLazyLayouts}) = simplify(M)

Expand Down
41 changes: 36 additions & 5 deletions src/lazybroadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@
BroadcastStyle(::Type{<:Adjoint{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
BroadcastStyle(::Type{<:Transpose{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
BroadcastStyle(::Type{<:SubArray{<:Any,1,<:LazyMatrix,<:Tuple{Slice,Any}}}) = LazyArrayStyle{1}()

BroadcastStyle(::Type{<:UpperOrLowerTriangular{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
BroadcastStyle(::Type{<:LinearAlgebra.HermOrSym{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()

Check warning on line 142 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L141-L142

Added lines #L141 - L142 were not covered by tests


BroadcastStyle(L::LazyArrayStyle{N}, ::StructuredMatrixStyle) where N = L


Expand Down Expand Up @@ -397,11 +402,37 @@
###

_broadcast_mul_mul(A, B) = simplify(Mul(broadcast(*, A...), B))
_broadcast_mul_mul((a,B)::Tuple{AbstractVector,AbstractMatrix}, C) = a .* (B*C)
_broadcast_mul_mul((A,b)::Tuple{AbstractMatrix,AbstractVector}, C) = b .* (A*C)
@inline copy(M::Mul{BroadcastLayout{typeof(*)}}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:AbstractLazyLayout}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
@inline copy(M::Mul{BroadcastLayout{typeof(*)},ApplyLayout{typeof(*)}}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
_broadcast_mul_mul(::typeof(*), A, B) = _broadcast_mul_mul(A, B) # maintain back-compatibility with Quasi/ContiuumArrays.jl
_broadcast_mul_simplifiable(op, A, B) = Val(false)
_broadcast_mul_mul(op, A, B) = simplify(Mul(broadcast(op, A...), B))

Check warning on line 407 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L405-L407

Added lines #L405 - L407 were not covered by tests

for op in (:*, :\)
@eval begin
_broadcast_mul_simplifiable(::typeof($op), (a,B)::Tuple{Union{AbstractVector,Number},AbstractMatrix}, C) = simplifiable(*, B, C)
_broadcast_mul_mul(::typeof($op), (a,B)::Tuple{Union{AbstractVector,Number},AbstractMatrix}, C) = broadcast($op, a, (B*C))

Check warning on line 412 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L411-L412

Added lines #L411 - L412 were not covered by tests
end
end

for op in (:*, :/)
@eval begin
_broadcast_mul_simplifiable(::typeof($op), (A,b)::Tuple{AbstractMatrix,Union{AbstractVector,Number}}, C) = simplifiable(*, A, C)
_broadcast_mul_mul(::typeof($op), (A,b)::Tuple{AbstractMatrix,Union{AbstractVector,Number}}, C) = broadcast($op, (A*C), b)

Check warning on line 419 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L418-L419

Added lines #L418 - L419 were not covered by tests
end
end



for op in (:*, :/, :\)
@eval begin
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)}}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)},<:LazyLayouts}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)},ApplyLayout{typeof(*)}}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
@inline copy(M::Mul{BroadcastLayout{typeof($op)}}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
@inline copy(M::Mul{BroadcastLayout{typeof($op)},<:LazyLayouts}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
@inline copy(M::Mul{BroadcastLayout{typeof($op)},ApplyLayout{typeof(*)}}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)

Check warning on line 432 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L427-L432

Added lines #L427 - L432 were not covered by tests
end
end


for op in (:*, :\, :/)
@eval begin
Expand Down
9 changes: 6 additions & 3 deletions src/linalg/inv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,13 @@
simplifiable(::Mul{<:AbstractInvLayout}) = Val(true)

copy(M::Mul{<:AbstractInvLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
copy(M::Mul{<:AbstractInvLayout,<:AbstractLazyLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
@inline copy(M::Mul{<:AbstractInvLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
@inline copy(M::Mul{<:AbstractInvLayout,ApplyLayout{typeof(*)}}) = simplify(M)
copy(M::Mul{<:AbstractInvLayout, <:AbstractLazyLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
@inline copy(M::Mul{<:AbstractInvLayout, <:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
@inline copy(M::Mul{<:AbstractInvLayout, ApplyLayout{typeof(*)}}) = simplify(M)

Check warning on line 93 in src/linalg/inv.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/inv.jl#L91-L93

Added lines #L91 - L93 were not covered by tests
copy(L::Ldiv{<:AbstractInvLayout}) = pinv(L.A) * L.B
copy(L::Ldiv{<:AbstractInvLayout, <:AbstractLazyLayout}) = pinv(L.A) * L.B
copy(L::Ldiv{<:AbstractInvLayout, <:AbstractInvLayout}) = pinv(L.A) * L.B
copy(L::Ldiv{<:AbstractInvLayout, ApplyLayout{typeof(*)}}) = pinv(L.A) * L.B

Check warning on line 97 in src/linalg/inv.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/inv.jl#L95-L97

Added lines #L95 - L97 were not covered by tests
Ldiv(A::Applied{<:Any,typeof(\)}) = Ldiv(A.args...)


Expand Down
3 changes: 3 additions & 0 deletions src/padded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@
simplifiable(::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)


@inline simplifiable(M::Mul{BroadcastLayout{typeof(*)},<:Union{PaddedColumns,PaddedLayout}}) = simplifiable(Mul{BroadcastLayout{typeof(*)},UnknownLayout}(M.A,M.B))
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:Union{PaddedColumns,PaddedLayout}}) = copy(Mul{BroadcastLayout{typeof(*)},UnknownLayout}(M.A,M.B))

Check warning on line 514 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L513-L514

Added lines #L513 - L514 were not covered by tests

simplifiable(::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
copy(M::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
Expand Down
7 changes: 7 additions & 0 deletions test/bandedtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,13 @@ LinearAlgebra.lmul!(β::Number, A::PseudoBandedMatrix) = (lmul!(β, A.data); A)
@test MemoryLayout(BroadcastMatrix(cos, A)) isa BroadcastLayout
end
end

@testset "broadcast_mul_mul" begin
A = BroadcastMatrix(*, randn(5,5), randn(5,5))
B = ApplyArray(*, brand(5,5,1,2), brand(5,5,2,1))
@test A * UpperTriangular(B) ≈ Matrix(A) * UpperTriangular(B)
@test simplifiable(*, A, UpperTriangular(B)) == Val(false) # TODO: probably should be true
end
end

@testset "Cache" begin
Expand Down
32 changes: 30 additions & 2 deletions test/broadcasttests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BroadcastTests

using LazyArrays, ArrayLayouts, LinearAlgebra, FillArrays, Base64, Test
using StaticArrays, Tracker
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle, sub_materialize
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle, sub_materialize, simplifiable
import Base: broadcasted

using ..InfiniteArrays
Expand Down Expand Up @@ -224,6 +224,9 @@ using Infinities
@test A[:,2] ≈ Ã[:,2] ≈ Matrix(A)[:,2]
@test C*b ≈ Matrix(C)*b

@test simplifiable(*, A, B) == Val(true)
@test simplifiable(*, Ã, B) == Val(true)

D = Diagonal(Fill(2,4))
@test A*D ≈ Matrix(A)*D
end
Expand Down Expand Up @@ -396,7 +399,7 @@ using Infinities
@test a[:,1:3] isa Adjoint{Int,Vector{Int}}
end

@testset "broadcast with adjtrans" begin
@testset "broadcast with adjtrans/triangular/hermsym" begin
a = BroadcastArray(real, ((1:5) .+ im))
b = BroadcastArray(exp, ((1:5) .+ im))
@test exp.(transpose(a)) isa Transpose{<:Any,<:BroadcastVector}
Expand All @@ -407,8 +410,17 @@ using Infinities
@test exp.(b') isa BroadcastMatrix
@test exp.(transpose(b)) == transpose(exp.(b))
@test exp.(b') == exp.(b)'

A = BroadcastArray(*, ((1:5) .+ im), (1:5)')
@test exp.(UpperTriangular(A)) isa BroadcastArray
@test exp.(Symmetric(A)) isa BroadcastArray
@test exp.(Hermitian(A)) isa BroadcastArray
@test exp.(UpperTriangular(A)) == exp.(UpperTriangular(Matrix(A)))
@test exp.(Symmetric(A)) == exp.(Symmetric(Matrix(A)))
@test exp.(Hermitian(A)) == exp.(Hermitian(Matrix(A)))
end


@testset "linear indexing" begin
a = BroadcastArray(real, ((1:5) .+ im))
b = BroadcastArray(exp, ((1:5) .+ im))
Expand All @@ -421,6 +433,22 @@ using Infinities
a = BroadcastArray(Base.literal_pow, ^, 1:5, Val(2))
@test last(a) == 25
end

@testset "BroadcastArray(*) * MulArray" begin
A = BroadcastArray(*, 1:3, randn(3,4))
B = ApplyArray(*, randn(4,3), randn(3,4))
@test A*B ≈ Matrix(A)*Matrix(B)
@test A*UpperTriangular(B) ≈ Matrix(A)*UpperTriangular(Matrix(B))
@test simplifiable(*,A,B) == Val(false) # TODO: Why False?
@test simplifiable(*,A,UpperTriangular(B)) == Val(false) # TODO: Why False?
end

@testset "/" begin
A = BroadcastArray(/, randn(3,4), randn(3,4))
B = randn(4,3)
@test A*B ≈ Matrix(A)*Matrix(B)
@test simplifiable(*,A,B) == Val(false) # TODO: Why False?
end
end

end #module
2 changes: 1 addition & 1 deletion test/cachetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ using Infinities
@testset "linalg" begin
c = cache(Fill(3,3,3))
@test fill(2,1,3) * c == fill(18,1,3)
@test ApplyMatrix(exp,fill(3,3,3)) * c == exp(fill(3,3,3)) * fill(3,3,3)
@test ApplyMatrix(exp,fill(3,3,3)) * c exp(fill(3,3,3)) * fill(3,3,3)
@test BroadcastMatrix(exp,fill(3,3,3)) * c == exp.(fill(3,3,3)) * fill(3,3,3)
@test fill(2,3)' * c == fill(18,1,3)
@test fill(2,3,1)' * c == fill(18,1,3)
Expand Down
8 changes: 8 additions & 0 deletions test/ldivtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,12 @@ end
@test rowsupport(invL, ()) == 1:0
end

@testset "Inv \\ Lazy" begin
A = randn(5,5)
Ai = InvMatrix(A)
@test Ai \ Ai ≈ I
@test Ai \ BroadcastArray(exp, A) ≈ Ai \ exp.(A) ≈ A*exp.(A)
@test Ai \ ApplyArray(*, A, A) ≈ A^3
end

end # module
7 changes: 7 additions & 0 deletions test/paddedtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -430,5 +430,12 @@ paddeddata(a::PaddedPadded) = a
@test_throws SingularException ArrayLayouts.ldiv!(Bidiagonal(-1:3, 1:4, :L), c)
@test_throws SingularException ArrayLayouts.ldiv!(Bidiagonal(-4:0, 1:4, :L), c)
end

@testset "Broadcast * Padded" begin
B = BroadcastArray(*, 1:8, (2:9)')
p = Vcat(1:2, Zeros(6))
@test B*p == Matrix(B)*p
@test simplifiable(*,B,p) == Val(true)
end
end
end # module
Loading