Skip to content

Commit e0e667f

Browse files
committed
1 parent a25656d commit e0e667f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

Diff for: ext/AbstractFFTsChainRulesCoreExt.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3030
halfdim = first(dims)
3131
d = size(x, halfdim)
3232
n = size(y, halfdim)
33-
scale = reshape(
33+
scale = typeof(y)(reshape(
3434
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
3535
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
36-
)
36+
))
3737

3838
project_x = ChainRulesCore.ProjectTo(x)
3939
function rfft_pullback(ȳ)
@@ -72,10 +72,10 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7272
n = size(x, halfdim)
7373
invN = AbstractFFTs.normalization(y, dims)
7474
twoinvN = 2 * invN
75-
scale = reshape(
75+
scale = typeof(y)(reshape(
7676
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
7777
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
78-
)
78+
))
7979

8080
project_x = ChainRulesCore.ProjectTo(x)
8181
function irfft_pullback(ȳ)
@@ -111,10 +111,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
111111
# compute scaling factors
112112
halfdim = first(dims)
113113
n = size(x, halfdim)
114-
scale = reshape(
114+
scale = typeof(y)(reshape(
115115
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
116116
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
117-
)
117+
))
118118

119119
project_x = ChainRulesCore.ProjectTo(x)
120120
function brfft_pullback(ȳ)

0 commit comments

Comments
 (0)