Skip to content

Commit b493372

Browse files
Merge pull request #252 from ioannisPApapadopoulos/jp/arrays
Extend FTPlans to Arrays and higher dimensions
2 parents d8ccedd + 9aca321 commit b493372

File tree

5 files changed

+153
-1
lines changed

5 files changed

+153
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FastTransforms"
22
uuid = "057dd010-8810-581a-b7be-e3fc3b93f78c"
3-
version = "0.16.4"
3+
version = "0.16.5"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/FastTransforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ for f in (:jac2jac,
128128
@eval $f(x::AbstractArray, y...; z...) = $lib_f(x, y...; z...)
129129
end
130130

131+
include("arrays.jl")
131132
# following use Toeplitz-Hankel to avoid expensive plans
132133
# for f in (:leg2cheb, :cheb2leg, :ultra2ultra)
133134
# th_f = Symbol("th_", f)

src/arrays.jl

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
struct ArrayPlan{T, FF<:FTPlan{<:T}, Szs<:Tuple, Dims<:Tuple{<:Int}} <: Plan{T}
2+
F::FF
3+
szs::Szs
4+
dims::Dims
5+
end
6+
size(P::ArrayPlan) = P.szs
7+
8+
function ArrayPlan(F::FTPlan{<:T}, c::AbstractArray{T}, dims::Tuple{<:Int}=(1,)) where T
9+
szs = size(c)
10+
@assert F.n == szs[dims[1]]
11+
ArrayPlan(F, size(c), dims)
12+
end
13+
14+
function *(P::ArrayPlan, f::AbstractArray)
15+
F, dims, szs = P.F, P.dims, P.szs
16+
@assert length(dims) == 1
17+
@assert szs == size(f)
18+
d = first(dims)
19+
20+
perm = (d, ntuple(i-> i + (i >= d), ndims(f) -1)...)
21+
fp = permutedims(f, perm)
22+
23+
fr = reshape(fp, size(fp,1), :)
24+
25+
permutedims(reshape(F*fr, size(fp)...), invperm(perm))
26+
end
27+
28+
function \(P::ArrayPlan, f::AbstractArray)
29+
F, dims, szs = P.F, P.dims, P.szs
30+
@assert length(dims) == 1
31+
@assert szs == size(f)
32+
d = first(dims)
33+
34+
perm = (d, ntuple(i-> i + (i >= d), ndims(f) -1)...)
35+
fp = permutedims(f, perm)
36+
37+
fr = reshape(fp, size(fp,1), :)
38+
39+
permutedims(reshape(F\fr, size(fp)...), invperm(perm))
40+
end
41+
42+
struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Szs<:Tuple, Dims<:Tuple} <: Plan{T}
43+
F::FF
44+
szs::Szs
45+
dims::Dims
46+
function NDimsPlan(F, szs, dims)
47+
if length(Set(szs[[dims...]])) > 1
48+
error("Different size in dims axes not yet implemented in N-dimensional transform.")
49+
end
50+
new{eltype(F), typeof(F), typeof(szs), typeof(dims)}(F, szs, dims)
51+
end
52+
end
53+
54+
size(P::NDimsPlan) = P.szs
55+
56+
function NDimsPlan(F::FTPlan, szs::Tuple, dims::Tuple)
57+
NDimsPlan(ArrayPlan(F, szs, (first(dims),)), szs, dims)
58+
end
59+
60+
function *(P::NDimsPlan, f::AbstractArray)
61+
F, dims = P.F, P.dims
62+
@assert size(P) == size(f)
63+
g = copy(f)
64+
t = 1:ndims(g)
65+
d1 = dims[1]
66+
for d in dims
67+
perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g))
68+
gp = permutedims(g, perm)
69+
g = permutedims(F*gp, invperm(perm))
70+
end
71+
return g
72+
end
73+
74+
function \(P::NDimsPlan, f::AbstractArray)
75+
F, dims = P.F, P.dims
76+
@assert size(P) == size(f)
77+
g = copy(f)
78+
t = 1:ndims(g)
79+
d1 = dims[1]
80+
for d in dims
81+
perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g))
82+
gp = permutedims(g, perm)
83+
g = permutedims(F\gp, invperm(perm))
84+
end
85+
return g
86+
end

test/arraystests.jl

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using FastTransforms, Test
2+
import FastTransforms: ArrayPlan, NDimsPlan
3+
4+
@testset "Array transform" begin
5+
@testset "ArrayPlan" begin
6+
c = randn(5,20,10)
7+
F = plan_cheb2leg(c)
8+
FT = ArrayPlan(F, c)
9+
10+
@test size(FT) == size(c)
11+
12+
f = similar(c);
13+
for k in axes(c,3)
14+
f[:,:,k] = (F*c[:,:,k])
15+
end
16+
@test f FT*c
17+
@test c FT\f
18+
19+
F = plan_cheb2leg(Vector{Float64}(axes(c,2)))
20+
FT = ArrayPlan(F, c, (2,))
21+
for k in axes(c,3)
22+
f[:,:,k] = (F*c[:,:,k]')'
23+
end
24+
@test f FT*c
25+
@test c FT\f
26+
end
27+
28+
@testset "NDimsPlan" begin
29+
c = randn(20,10,20)
30+
@test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), size(c), (1,2))
31+
32+
c = randn(5,20)
33+
F = plan_cheb2leg(c)
34+
FT = ArrayPlan(F, c)
35+
P = NDimsPlan(F, size(c), (1,))
36+
@test F*c FT*c P*c
37+
38+
c = randn(20,20,5);
39+
F = plan_cheb2leg(c)
40+
FT = ArrayPlan(F, c)
41+
P = NDimsPlan(FT, size(c), (1,2))
42+
43+
@test size(P) == size(c)
44+
45+
f = similar(c);
46+
for k in axes(f,3)
47+
f[:,:,k] = (F*(F*c[:,:,k])')'
48+
end
49+
@test f P*c
50+
@test c P\f
51+
52+
c = randn(5,10,10,60)
53+
F = plan_cheb2leg(randn(10))
54+
P = NDimsPlan(F, size(c), (2,3))
55+
f = similar(c)
56+
for i in axes(f,1), j in axes(f,4)
57+
f[i,:,:,j] = (F*(F*c[i,:,:,j])')'
58+
end
59+
@test f P*c
60+
@test c P\f
61+
end
62+
end
63+
64+

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ include("clenshawtests.jl")
1212
include("toeplitzplanstests.jl")
1313
include("toeplitzhankeltests.jl")
1414
include("symmetrictoeplitzplushankeltests.jl")
15+
include("arraystests.jl")

0 commit comments

Comments
 (0)