Skip to content

Commit 576463f

Browse files
authored
4-tensor Mul/InvPlan (#172)
* 4-tensor Mul/InvPlan * Update test_splines.jl
1 parent 1cb30aa commit 576463f

File tree

3 files changed

+87
-76
lines changed

3 files changed

+87
-76
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.17"
3+
version = "0.17.1"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/plans.jl

+64-75
Original file line numberDiff line numberDiff line change
@@ -42,47 +42,6 @@ InvPlan(fact, dims) = InvPlan((fact,), dims)
4242
size(F::InvPlan) = size.(F.factorizations, 1)
4343

4444

45-
function *(P::InvPlan{<:Any,<:Tuple,Int}, x::AbstractVector)
46-
@assert P.dims == 1
47-
only(P.factorizations) \ x # Only a single factorization when dims isa Int
48-
end
49-
50-
function *(P::InvPlan{<:Any,<:Tuple,Int}, X::AbstractMatrix)
51-
if P.dims == 1
52-
only(P.factorizations) \ X # Only a single factorization when dims isa Int
53-
else
54-
@assert P.dims == 2
55-
permutedims(only(P.factorizations) \ permutedims(X))
56-
end
57-
end
58-
59-
function *(P::InvPlan{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3})
60-
Y = similar(X)
61-
if P.dims == 1
62-
for j in axes(X,3)
63-
Y[:,:,j] = only(P.factorizations) \ X[:,:,j]
64-
end
65-
elseif P.dims == 2
66-
for k in axes(X,1)
67-
Y[k,:,:] = only(P.factorizations) \ X[k,:,:]
68-
end
69-
else
70-
@assert P.dims == 3
71-
for k in axes(X,1), j in axes(X,2)
72-
Y[k,j,:] = only(P.factorizations) \ X[k,j,:]
73-
end
74-
end
75-
Y
76-
end
77-
78-
function *(P::InvPlan, X::AbstractArray)
79-
for d in P.dims
80-
X = InvPlan(P.factorizations[d], d) * X
81-
end
82-
X
83-
end
84-
85-
8645
"""
8746
MulPlan(matrix, dims)
8847
@@ -96,44 +55,74 @@ end
9655
MulPlan(mats::Tuple, dims) = MulPlan{eltype(mats), typeof(mats), typeof(dims)}(mats, dims)
9756
MulPlan(mats::AbstractMatrix, dims) = MulPlan((mats,), dims)
9857

99-
function *(P::MulPlan{<:Any,<:Tuple,Int}, x::AbstractVector)
100-
@assert P.dims == 1
101-
only(P.matrices) * x
102-
end
103-
104-
function *(P::MulPlan{<:Any,<:Tuple,Int}, X::AbstractMatrix)
105-
if P.dims == 1
106-
only(P.matrices) * X
107-
else
108-
@assert P.dims == 2
109-
permutedims(only(P.matrices) * permutedims(X))
110-
end
111-
end
112-
113-
function *(P::MulPlan{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3})
114-
Y = similar(X)
115-
if P.dims == 1
116-
for j in axes(X,3)
117-
Y[:,:,j] = only(P.matrices) * X[:,:,j]
58+
for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizations)))
59+
@eval begin
60+
function *(P::$Pln{<:Any,<:Tuple,Int}, x::AbstractVector)
61+
@assert P.dims == 1
62+
$op(only(getfield(P, $fld)), x) # Only a single factorization when dims isa Int
11863
end
119-
elseif P.dims == 2
120-
for k in axes(X,1)
121-
Y[k,:,:] = only(P.matrices) * X[k,:,:]
64+
65+
function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractMatrix)
66+
if P.dims == 1
67+
$op(only(getfield(P, $fld)), X) # Only a single factorization when dims isa Int
68+
else
69+
@assert P.dims == 2
70+
permutedims($op(only(getfield(P, $fld)), permutedims(X)))
71+
end
12272
end
123-
else
124-
@assert P.dims == 3
125-
for k in axes(X,1), j in axes(X,2)
126-
Y[k,j,:] = only(P.matrices) * X[k,j,:]
73+
74+
function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3})
75+
Y = similar(X)
76+
if P.dims == 1
77+
for j in axes(X,3)
78+
Y[:,:,j] = $op(only(getfield(P, $fld)), X[:,:,j])
79+
end
80+
elseif P.dims == 2
81+
for k in axes(X,1)
82+
Y[k,:,:] = $op(only(getfield(P, $fld)), X[k,:,:])
83+
end
84+
else
85+
@assert P.dims == 3
86+
for k in axes(X,1), j in axes(X,2)
87+
Y[k,j,:] = $op(only(getfield(P, $fld)), X[k,j,:])
88+
end
89+
end
90+
Y
91+
end
92+
93+
function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,4})
94+
Y = similar(X)
95+
if P.dims == 1
96+
for j in axes(X,3), l in axes(X,4)
97+
Y[:,:,j,l] = $op(only(getfield(P, $fld)), X[:,:,j,l])
98+
end
99+
elseif P.dims == 2
100+
for k in axes(X,1), l in axes(X,4)
101+
Y[k,:,:,l] = $op(only(getfield(P, $fld)), X[k,:,:,l])
102+
end
103+
elseif P.dims == 3
104+
for k in axes(X,1), j in axes(X,2)
105+
Y[k,j,:,:] = $op(only(getfield(P, $fld)), X[k,j,:,:])
106+
end
107+
elseif P.dims == 4
108+
for k in axes(X,1), j in axes(X,2), l in axes(X,3)
109+
Y[k,j,l,:] = $op(only(getfield(P, $fld)), X[k,j,l,:])
110+
end
111+
end
112+
Y
113+
end
114+
115+
116+
117+
*(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray) = error("Overload")
118+
119+
function *(P::$Pln, X::AbstractArray)
120+
for (fac,dim) in zip(getfield(P, $fld), P.dims)
121+
X = $Pln(fac, dim) * X
122+
end
123+
X
127124
end
128125
end
129-
Y
130-
end
131-
132-
function *(P::MulPlan, X::AbstractArray)
133-
for d in P.dims
134-
X = MulPlan(P.matrices[d], d) * X
135-
end
136-
X
137126
end
138127

139128
*(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.dims)

test/test_splines.jl

+22
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,28 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
526526
X[k, j, :] = L[g,:] \ X[k, j, :]
527527
end
528528
@test PX X
529+
530+
n = size(L,2)
531+
X = randn(n, n, n, n)
532+
P = plan_transform(L, X)
533+
PX = P * X
534+
for k = 1:n, j = 1:n, l = 1:n
535+
X[:, k, j, l] = L[g,:] \ X[:, k, j, l]
536+
end
537+
for k = 1:n, j = 1:n, l = 1:n
538+
X[k, :, j, l] = L[g,:] \ X[k, :, j, l]
539+
end
540+
for k = 1:n, j = 1:n, l = 1:n
541+
X[k, j, :, l] = L[g,:] \ X[k, j, :, l]
542+
end
543+
for k = 1:n, j = 1:n, l = 1:n
544+
X[k, j, l, :] = L[g,:] \ X[k, j, l, :]
545+
end
546+
@test PX X
547+
548+
X = randn(n, n, n, n, n)
549+
P = plan_transform(L, X)
550+
@test_throws ErrorException P * X
529551
end
530552

531553
@testset "Mul coefficients" begin

0 commit comments

Comments
 (0)