Skip to content

Commit 6e7d8fd

Browse files
authored
Expand functions in P'(x .^2 .* P) (#191)
* Expand functions in P'(x .^2 .* P) * Update ci.yml * add tests and use expand * Delete Project.toml * Update bases.jl * add coverage * test adj broadcast * Add tests * Update test_chebyshev.jl * Update test_chebyshev.jl * add one mroe test
1 parent 09105e5 commit 6e7d8fd

File tree

5 files changed

+67
-6
lines changed

5 files changed

+67
-6
lines changed

.github/workflows/ci.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ jobs:
2020
fail-fast: false
2121
matrix:
2222
version:
23-
- '1.10'
23+
- 'lts'
24+
- '1'
2425
os:
2526
- ubuntu-latest
2627
- macOS-latest

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Infinities = "0.1"
3939
IntervalSets = "0.7"
4040
LazyArrays = "2"
4141
Makie = "0.20, 0.21"
42-
QuasiArrays = "0.11.5"
42+
QuasiArrays = "0.11.8"
4343
RecipesBase = "1.0"
4444
StaticArrays = "1.0"
4545
julia = "1.10"

src/bases/bases.jl

+42-2
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,54 @@ function _broadcast_mul_ldiv(::Tuple{ScalarLayout,ApplyLayout{typeof(*)}}, A, B)
145145
a * (A \ b)
146146
end
147147

148-
_broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) =
149-
_broadcast_mul_ldiv((ScalarLayout(),UnknownLayout()), A, B)
148+
_broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) = _broadcast_mul_ldiv((ScalarLayout(),UnknownLayout()), A, B)
150149
_broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
151150

152151
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
153152
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
154153

155154

156155

156+
# multiplication operators, reexpand in basis A
157+
@inline function _broadcast_mul_adj(::Tuple{Any,AbstractBasisLayout}, Ac, B)
158+
a,b = arguments(B)
159+
@assert a isa AbstractQuasiVector # Only works for vec .* mat
160+
A = Ac'
161+
ab = (A * (A \ a)) .* b # broadcasted should be overloaded
162+
MemoryLayout(ab) isa BroadcastLayout && return Ac*transform_ldiv(A, ab)
163+
Ac*ab
164+
end
165+
166+
@inline function _broadcast_mul_adj(::Tuple{Any,ApplyLayout{typeof(*)}}, Ac, B)
167+
a,b = arguments(B)
168+
@assert a isa AbstractQuasiVector # Only works for vec .* mat
169+
args = arguments(*, b)
170+
*(Ac*(a .* first(args)), tail(args)...)
171+
end
172+
173+
174+
function _broadcast_mul_adj(::Tuple{ScalarLayout,Any}, Ac, B)
175+
a,b = arguments(B)
176+
a * (Ac*b)
177+
end
178+
179+
function _broadcast_mul_adj(::Tuple{ScalarLayout,ApplyLayout{typeof(*)}}, Ac, B)
180+
a,b = arguments(B)
181+
a * (Ac*b)
182+
end
183+
184+
_broadcast_mul_adj(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) = _broadcast_mul_adj((ScalarLayout(),UnknownLayout()), A, B)
185+
_broadcast_mul_adj(_, A, B) = copy(Mul{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
186+
187+
_broadcast_mul_adj_simplifiable(_, ::AbstractBasisLayout) = Val(true)
188+
_broadcast_mul_adj_simplifiable(_, ::ApplyLayout{typeof(*)}) = Val(true)
189+
_broadcast_mul_adj_simplifiable(::ScalarLayout, _) = Val(true)
190+
_broadcast_mul_adj_simplifiable(::ScalarLayout, ::ApplyLayout{typeof(*)}) = Val(true)
191+
_broadcast_mul_adj_simplifiable(::ScalarLayout, ::AbstractBasisLayout) = Val(true)
192+
_broadcast_mul_adj_simplifiable(_, _) = Val(false)
193+
194+
simplifiable(L::Mul{<:AdjointBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_adj_simplifiable(map(MemoryLayout,arguments(L.B))...)
195+
copy(L::Mul{<:AdjointBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_adj(map(MemoryLayout,arguments(L.B)), L.A, L.B)
157196

158197

159198
"""
@@ -651,6 +690,7 @@ diff_layout(::ExpansionLayout, A, dims...) = diff_layout(ApplyLayout{typeof(*)}(
651690
####
652691

653692
simplifiable(::Mul{<:AdjointBasisLayout, <:AbstractBasisLayout}) = Val(true)
693+
@inline simplifiable(L::Mul{<:AdjointBasisLayout,ApplyLayout{typeof(*)}}) = simplifiable(*, L.A, first(arguments(*, L.B)))
654694
function copy(M::Mul{<:AdjointBasisLayout, <:AbstractBasisLayout})
655695
A = (M.A)'
656696
A == M.B && return grammatrix(A)

test/test_chebyshev.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ContinuumArrays, LinearAlgebra, FastTransforms, QuasiArrays, ArrayLayouts, Base64, LazyArrays, Test
1+
using ContinuumArrays, LinearAlgebra, QuasiArrays, ArrayLayouts, Base64, LazyArrays, Test
2+
using FastTransforms
23
import ContinuumArrays: Basis, Weight, Map, LazyQuasiArrayStyle, TransformFactorization,
34
ExpansionLayout, checkpoints, MappedBasisLayout, MappedWeightedBasisLayout,
45
SubWeightedBasisLayout, WeightedBasisLayout, WeightLayout, basis, grammatrix
@@ -159,6 +160,25 @@ Base.:(==)(::FooBasis, ::FooBasis) = true
159160

160161
= T * (T \ a)
161162
@test T \ (ã .* ã) [1.5,1,0.5,0,0]
163+
164+
@test T'*(a .* T) isa Matrix
165+
@test T'*(a .* (T * (T \ a))) isa Vector
166+
@test_broken T'f isa Vector
167+
@test T'isa Vector
168+
@test T'*(ã .* ã) isa Vector
169+
@test (2T)'*(a .* T) isa Matrix
170+
@test T'*(2T) isa Matrix
171+
@test T'*(2T*randn(5)) isa Vector
172+
@test (2T)'*(T*(1:5)) T'*(2T*(1:5)) T'BroadcastQuasiMatrix(*, 2, T*(1:5))
173+
@test T' * (a .* (T * (1:5))) T' * ((a .* T) * (1:5))
174+
@test T'BroadcastQuasiMatrix(*, 2, 2T) == 4*(T'T)
175+
176+
@test LazyArrays.simplifiable(*, T', T*(1:5)) == Val(true)
177+
@test LazyArrays.simplifiable(*, T', (a .* (T * (1:5)))) == Val(true)
178+
@test LazyArrays.simplifiable(*, T', a .* T) == Val(true)
179+
@test LazyArrays.simplifiable(*, T', 2T) == Val(true)
180+
@test LazyArrays.simplifiable(*, T', BroadcastQuasiMatrix(*, 2, T*(1:5))) == Val(true)
181+
@test LazyArrays.simplifiable(*, T', BroadcastQuasiMatrix(*, 2, 2T)) == Val(true)
162182
end
163183

164184
@testset "sum/dot/diff" begin

test/test_splines.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
362362
@test Δ == -(*(B',D',D,B))
363363
@test Δ == -(B'D'D*B)
364364
@test Δ == -((B'D')*(D*B))
365-
@test_broken Δ == -B'*(D'D)*B
365+
@test Δ == -B'*(D'D)*B
366366
@test Δ == -(B'*(D'D)*B)
367367

368368
f = L*exp.(L.points) # project exp(x)

0 commit comments

Comments
 (0)