Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rules for FFTs #541

Closed
wants to merge 16 commits into from
Closed
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.6'
- '1'
- 'nightly'
os:
Expand Down
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.33"
version = "0.11"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Expand All @@ -16,24 +17,27 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
AbstractFFTs = "0.5, 1"
Calculus = "0.2, 0.3, 0.4, 0.5"
CommonSubexpressions = "0.3"
DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1"
DiffRules = "1.4.0"
DiffTests = "0.0.1, 0.1"
FFTW = "1"
LogExpFunctions = "0.3"
NaNMath = "0.2.2, 0.3, 1"
Preferences = "1"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
julia = "1"
julia = "1.6"

[extras]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"]
test = ["Calculus", "DiffTests", "FFTW", "SparseArrays", "Test", "InteractiveUtils"]
3 changes: 3 additions & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ include("gradient.jl")
include("jacobian.jl")
include("hessian.jl")

import AbstractFFTs
include("fft.jl")

export DiffResults

end # module
2 changes: 2 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ end
@inline tagtype(::Type{V}) where {V} = Nothing
@inline tagtype(::Dual{T,V,N}) where {T,V,N} = T
@inline tagtype(::Type{Dual{T,V,N}}) where {T,V,N} = T
@inline tagtype(::Complex{T}) where T = tagtype(T)
@inline tagtype(::Type{Complex{T}}) where T = tagtype(T)

####################################
# N-ary Operation Definition Tools #
Expand Down
84 changes: 84 additions & 0 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

value(x::Complex{<:Dual}) =
Complex(x.re.value, x.im.value)

partials(x::Complex{<:Dual}, n::Int) =
Complex(partials(x.re, n), partials(x.im, n))

npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N
npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N

# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im)
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im

AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)

for plan in [:plan_fft, :plan_ifft, :plan_bfft]
@eval begin

AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

end
end

# rfft only accepts real arrays
AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.plan_rfft(value.(x), region)

for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex?
@eval begin

AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), d, region)

end
end

# for f in (:dct, :idct)
# pf = Symbol("plan_", f)
# @eval begin
# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x
# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x
# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...)
# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...)
# end
# end


for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities
@eval begin

Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) =
_apply_plan(p, x)

Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) =
_apply_plan(p, x)

end
end

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
xtil = p * value.(x)
dxtils = ntuple(npartials(eltype(x))) do n
p * partials.(x, n)
end
__apply_plan(tagtype(eltype(x)), xtil, dxtils)
end

function __apply_plan(T, xtil, dxtils)
map(xtil, dxtils...) do val, parts...
Complex(
Dual{T}(real(val), map(real, parts)),
Dual{T}(imag(val), map(imag, parts)),
)
end
end
1 change: 1 addition & 0 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct TestTag end
struct OuterTestTag end

samerng() = MersenneTwister(1)
Random.seed!(132)

# By lower-bounding the Int range at 2, we avoid cases where differentiating an
# exponentiation of an Int value would cause a DomainError due to reducing the
Expand Down
32 changes: 32 additions & 0 deletions test/FFTTest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module FFTTest

using Test
using ForwardDiff: Dual, valtype, value, partials, derivative
using FFTW
using AbstractFFTs: complexfloat, realfloat


x1 = Dual.(1:4.0, 2:5, 3:6)

@test value.(x1) == 1:4
@test partials.(x1, 1) == 2:5

@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)

@test fft(x1, 1)[1] isa Complex{<:Dual}

@testset "$f" for f in [fft, ifft, rfft, bfft]
@test value.(f(x1)) == f(value.(x1))
@test partials.(f(x1), 1) == f(partials.(x1, 1))
end

f = x -> real(fft([x; 0; 0])[1])
@test derivative(f,0.1) ≈ 1

r = x -> real(rfft([x; 0; 0])[1])
@test derivative(r,0.1) ≈ 1

# c = x -> dct([x; 0; 0])[1]
# @test derivative(c,0.1) ≈ 1
end # module
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ Random.seed!(SEED)
end
end
println("##### Running all ForwardDiff tests took $(time() - t0) seconds.")
end
end