Skip to content

Commit 7d34bf2

Browse files
authored
Correct inverse plan logic (#69)
* Correct inverse plan caching logic in test plans * Use inv rather than plan_inv in scaled plan
1 parent 3e7d412 commit 7d34bf2

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

src/definitions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ plan_ifft(x::AbstractArray, region; kws...) =
278278
plan_ifft!(x::AbstractArray, region; kws...) =
279279
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))
280280

281-
plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
281+
plan_inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale))
282282

283283
LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
284284
LinearAlgebra.lmul!(p.scale, LinearAlgebra.mul!(y, p.p, x))

test/testplans.jl

+24-24
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,20 @@ end
2727
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return InverseTestPlan{T}(region, size(x))
2929
end
30+
3031
function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
3132
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
32-
unscaled_pinv.pinv = p
33-
pinv = AbstractFFTs.ScaledPlan(
34-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
35-
)
33+
N = AbstractFFTs.normalization(T, p.sz, p.region)
34+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
35+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
3636
return pinv
3737
end
38-
function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T}
39-
unscaled_pinv = TestPlan{T}(p.region, p.sz)
40-
unscaled_pinv.pinv = p
41-
pinv = AbstractFFTs.ScaledPlan(
42-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
43-
)
44-
return pinv
38+
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
39+
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
40+
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
41+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
42+
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
43+
return p
4544
end
4645

4746
# Just a helper function since forward and backward are nearly identical
@@ -118,22 +117,23 @@ function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
118117
firstdim = first(p.region)::Int
119118
d = p.sz[firstdim]
120119
sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N))
120+
_N = AbstractFFTs.normalization(T, p.sz, p.region)
121+
121122
unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz)
122-
unscaled_pinv.pinv = p
123-
pinv = AbstractFFTs.ScaledPlan(
124-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
125-
)
123+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
124+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
126125
return pinv
127126
end
128-
function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N}
129-
firstdim = first(p.region)::Int
130-
sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N))
131-
unscaled_pinv = TestRPlan{T}(p.region, sz)
132-
unscaled_pinv.pinv = p
133-
pinv = AbstractFFTs.ScaledPlan(
134-
unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region),
135-
)
136-
return pinv
127+
128+
function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N}
129+
firstdim = first(pinv.region)::Int
130+
sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N))
131+
_N = AbstractFFTs.normalization(T, sz, pinv.region)
132+
133+
unscaled_p = TestRPlan{T}(pinv.region, sz)
134+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
135+
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
136+
return p
137137
end
138138

139139
Base.size(p::TestRPlan) = p.sz

0 commit comments

Comments
 (0)