@@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
30
30
halfdim = first (dims)
31
31
d = size (x, halfdim)
32
32
n = size (y, halfdim)
33
- scale = reshape (
33
+ scale = typeof (y)( reshape (
34
34
[i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
35
35
ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
36
- )
36
+ ))
37
37
38
38
project_x = ChainRulesCore. ProjectTo (x)
39
39
function rfft_pullback (ȳ)
@@ -72,10 +72,10 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
72
72
n = size (x, halfdim)
73
73
invN = AbstractFFTs. normalization (y, dims)
74
74
twoinvN = 2 * invN
75
- scale = reshape (
75
+ scale = typeof (y)( reshape (
76
76
[i == 1 || (i == n && 2 * (i - 1 ) == d) ? invN : twoinvN for i in 1 : n],
77
77
ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
78
- )
78
+ ))
79
79
80
80
project_x = ChainRulesCore. ProjectTo (x)
81
81
function irfft_pullback (ȳ)
@@ -111,10 +111,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
111
111
# compute scaling factors
112
112
halfdim = first (dims)
113
113
n = size (x, halfdim)
114
- scale = reshape (
114
+ scale = typeof (y)( reshape (
115
115
[i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
116
116
ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
117
- )
117
+ ))
118
118
119
119
project_x = ChainRulesCore. ProjectTo (x)
120
120
function brfft_pullback (ȳ)
0 commit comments