Skip to content

Commit 1cf16ae

Browse files
add region to generic plans
1 parent a603718 commit 1cf16ae

File tree

2 files changed

+117
-57
lines changed

2 files changed

+117
-57
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FastTransforms"
22
uuid = "057dd010-8810-581a-b7be-e3fc3b93f78c"
3-
version = "0.11.2"
3+
version = "0.11.3"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/fftBigFloat.jl

+116-56
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@ const ComplexFloats = Complex{T} where T<:AbstractFloat
88
# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
99
# To add more types, add them in the union of the function's signature.
1010

11+
function generic_fft(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
12+
region == 1 && (ret = generic_fft(x))
13+
ret
14+
end
15+
16+
function generic_fft!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
17+
region == 1 && (x[:] .= generic_fft(x))
18+
x
19+
end
20+
21+
function generic_fft(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
22+
region == 1:1 && (ret = generic_fft(x))
23+
ret
24+
end
25+
26+
function generic_fft!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
27+
region == 1:1 && (x[:] .= generic_fft(x))
28+
x
29+
end
30+
31+
function generic_fft(x::StridedMatrix{T}, region::Integer) where T<:AbstractFloats
32+
if region == 1
33+
ret = hcat([generic_fft(x[:, j]) for j in 1:size(x, 2)]...)
34+
end
35+
ret
36+
end
37+
38+
function generic_fft!(x::StridedMatrix{T}, region::Integer) where T<:AbstractFloats
39+
if region == 1
40+
for j in 1:size(x, 2)
41+
x[:, j] .= generic_fft(x[:, j])
42+
end
43+
end
44+
x
45+
end
46+
1147
function generic_fft(x::Vector{T}) where T<:AbstractFloats
1248
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
1349
n = length(x)
@@ -18,36 +54,20 @@ function generic_fft(x::Vector{T}) where T<:AbstractFloats
1854
return Wks.*conv(xq,wq)[n+1:2n]
1955
end
2056

57+
generic_bfft(x::StridedArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft(conj(x), region))
58+
generic_bfft!(x::StridedArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft!(conj!(x), region))
59+
generic_ifft(x::StridedArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(length(x), conj!(generic_fft(conj(x), region)))
60+
generic_ifft!(x::StridedArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(length(x), conj!(generic_fft!(conj!(x), region)))
2161

22-
function generic_fft!(x::Vector{T}) where T<:AbstractFloats
23-
x[:] = generic_fft(x)
24-
return x
25-
end
26-
27-
# add rfft for AbstractFloat, by calling fft
28-
generic_rfft(v::Vector{T}) where T<:AbstractFloats = generic_fft(v)[1:div(length(v),2)+1]
29-
30-
function generic_irfft(v::Vector{T}, n::Integer) where T<:ComplexFloats
62+
generic_rfft(v::Vector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1]
63+
function generic_irfft(v::Vector{T}, n::Integer, region) where T<:ComplexFloats
3164
@assert n==2length(v)-1
3265
r = Vector{T}(undef, n)
3366
r[1:length(v)]=v
3467
r[length(v)+1:end]=reverse(conj(v[2:end]))
35-
real(generic_ifft(r))
36-
end
37-
38-
generic_bfft(x::Vector{T}) where {T <: AbstractFloats} = conj!(generic_fft(conj(x)))
39-
function generic_bfft!(x::Vector{T}) where {T <: AbstractFloats}
40-
x[:] = generic_bfft(x)
41-
return x
42-
end
43-
44-
generic_brfft(v::Vector, n::Integer) = generic_irfft(v, n)*n
45-
46-
generic_ifft(x::Vector{T}) where {T<:AbstractFloats} = conj!(generic_fft(conj(x)))/length(x)
47-
function generic_ifft!(x::Vector{T}) where T<:AbstractFloats
48-
x[:] = generic_ifft(x)
49-
return x
68+
real(generic_ifft(r, region))
5069
end
70+
generic_brfft(v::StridedArray, n::Integer, region) = generic_irfft(v, n, region)*n
5171

5272
function conv(u::StridedVector{T}, v::StridedVector{T}) where T<:AbstractFloats
5373
nu,nv = length(u),length(v)
@@ -112,6 +132,46 @@ function generic_ifft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
112132
return complex.(y[1:2:end],-y[2:2:end])/length(x)
113133
end
114134

135+
function generic_dct(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
136+
region == 1 && (ret = generic_dct(x))
137+
ret
138+
end
139+
140+
function generic_dct!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
141+
region == 1 && (x[:] .= generic_dct(x))
142+
x
143+
end
144+
145+
function generic_idct(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
146+
region == 1 && (ret = generic_idct(x))
147+
ret
148+
end
149+
150+
function generic_idct!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
151+
region == 1 && (x[:] .= generic_idct(x))
152+
x
153+
end
154+
155+
function generic_dct(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
156+
region == 1:1 && (ret = generic_dct(x))
157+
ret
158+
end
159+
160+
function generic_dct!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
161+
region == 1:1 && (x[:] .= generic_dct(x))
162+
x
163+
end
164+
165+
function generic_idct(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
166+
region == 1:1 && (ret = generic_idct(x))
167+
ret
168+
end
169+
170+
function generic_idct!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
171+
region == 1:1 && (x[:] .= generic_idct(x))
172+
x
173+
end
174+
115175
function generic_dct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
116176
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
117177
N = length(a)
@@ -139,8 +199,6 @@ end
139199

140200
generic_idct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_idct(complex(a)))
141201

142-
generic_dct!(a::AbstractArray{T}) where {T<:AbstractFloats} = (b = generic_dct(a); a[:] = b)
143-
generic_idct!(a::AbstractArray{T}) where {T<:AbstractFloats} = (b = generic_idct(a); a[:] = b)
144202

145203
# These lines mimick the corresponding ones in FFTW/src/dct.jl, but with
146204
# AbstractFloat rather than fftwNumber.
@@ -157,33 +215,35 @@ abstract type DummyPlan{T} <: Plan{T} end
157215
for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiDCTPlan)
158216
# All plans need an initially undefined pinv field
159217
@eval begin
160-
mutable struct $P{T,inplace} <: DummyPlan{T}
218+
mutable struct $P{T,inplace,G} <: DummyPlan{T}
219+
region::G # region (iterable) of dims that are transformed
161220
pinv::DummyPlan{T}
162-
$P{T,inplace}() where {T<:AbstractFloats, inplace} = new()
221+
$P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region)
163222
end
164223
end
165224
end
166225
for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
167226
@eval begin
168-
mutable struct $P{T,inplace} <: DummyPlan{T}
227+
mutable struct $P{T,inplace,G} <: DummyPlan{T}
169228
n::Integer
229+
region::G # region (iterable) of dims that are transformed
170230
pinv::DummyPlan{T}
171-
$P{T,inplace}(n::Integer) where {T<:AbstractFloats, inplace} = new(n)
231+
$P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region)
172232
end
173233
end
174234
end
175235

176236
for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
177237
(:DummyDCTPlan,:DummyiDCTPlan))
178238
@eval begin
179-
plan_inv(::$Plan{T,inplace}) where {T,inplace} = $iPlan{T,inplace}()
180-
plan_inv(::$iPlan{T,inplace}) where {T,inplace} = $Plan{T,inplace}()
239+
plan_inv(p::$Plan{T,inplace,G}) where {T,inplace,G} = $iPlan{T,inplace,G}(p.region)
240+
plan_inv(p::$iPlan{T,inplace,G}) where {T,inplace,G} = $Plan{T,inplace,G}(p.region)
181241
end
182242
end
183243

184244
# Specific for rfft, irfft and brfft:
185-
plan_inv(p::DummyirFFTPlan{T,inplace}) where {T,inplace} = DummyrFFTPlan{T,Inplace}(p.n)
186-
plan_inv(p::DummyrFFTPlan{T,inplace}) where {T,inplace} = DummyirFFTPlan{T,Inplace}(p.n)
245+
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,Inplace,G}(p.n, p.region)
246+
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,Inplace,G}(p.n, p.region)
187247

188248

189249

@@ -194,26 +254,26 @@ for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
194254
(:DummyDCTPlan,:generic_dct,:generic_dct!),
195255
(:DummyiDCTPlan,:generic_idct,:generic_idct!))
196256
@eval begin
197-
*(p::$Plan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff!(x)
198-
*(p::$Plan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff(x)
257+
*(p::$Plan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff!(x, p.region)
258+
*(p::$Plan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff(x, p.region)
199259
function mul!(C::StridedVector, p::$Plan, x::StridedVector)
200-
C[:] = $ff(x)
260+
C[:] = $ff(x, p.region)
201261
C
202262
end
203263
end
204264
end
205265

206266
# Specific for irfft and brfft:
207-
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n)
208-
*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n)
267+
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n, p.region)
268+
*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n, p.region)
209269
function mul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector)
210-
C[:] = generic_irfft(x, p.n)
270+
C[:] = generic_irfft(x, p.n, p.region)
211271
C
212272
end
213-
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n)
214-
*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n)
273+
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n, p.region)
274+
*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n, p.region)
215275
function mul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector)
216-
C[:] = generic_brfft(x, p.n)
276+
C[:] = generic_brfft(x, p.n, p.region)
217277
C
218278
end
219279

@@ -233,27 +293,27 @@ AbstractFFTs._fftfloat(::Type{T}) where {T <: AbstractFloat} = T
233293
# This is the reason for using StridedArray below. We also have to carefully
234294
# distinguish between real and complex arguments.
235295

236-
plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false}()
237-
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true}()
296+
plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false,typeof(region)}(region)
297+
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true,typeof(region)}(region)
238298

239-
plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false}()
240-
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true}()
299+
plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false,typeof(region)}(region)
300+
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true,typeof(region)}(region)
241301

242302
# The ifft plans are automatically provided in terms of the bfft plans above.
243-
# plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false}()
244-
# plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true}()
303+
# plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false,typeof(region)}(region)
304+
# plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true,typeof(region)}(region)
245305

246-
plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false}()
247-
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true}()
306+
plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false,typeof(region)}(region)
307+
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true,typeof(region)}(region)
248308

249-
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false}()
250-
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true}()
309+
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region)
310+
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region)
251311

252-
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false}(length(x))
253-
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false}(n)
312+
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false,typeof(region)}(length(x), region)
313+
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
254314

255315
# A plan for irfft is created in terms of a plan for brfft.
256-
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false}(n)
316+
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
257317

258318
# These don't exist for now:
259319
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()

0 commit comments

Comments
 (0)