Skip to content

Commit 83b3731

Browse files
authored
Merge pull request #2800 from stan-dev/feature/fft2-adjoints
Add 2d FFT adjoints
2 parents bdb3c86 + 669f352 commit 83b3731

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

stan/math/prim/fun/fft.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ inline Eigen::Matrix<scalar_type_t<V>, -1, 1> inv_fft(const V& y) {
8282
* @param[in] x matrix to transform
8383
* @return discrete 2D Fourier transform of `x`
8484
*/
85-
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr>
85+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
86+
require_not_var_t<base_type_t<value_type_t<M>>>* = nullptr>
8687
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> fft2(const M& x) {
8788
Eigen::Matrix<scalar_type_t<M>, -1, -1> y(x.rows(), x.cols());
8889
for (int i = 0; i < y.rows(); ++i)
@@ -103,7 +104,8 @@ inline Eigen::Matrix<scalar_type_t<M>, -1, -1> fft2(const M& x) {
103104
* @param[in] y matrix to inverse trnasform
104105
* @return inverse discrete 2D Fourier transform of `y`
105106
*/
106-
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr>
107+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
108+
require_not_var_t<base_type_t<value_type_t<M>>>* = nullptr>
107109
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> inv_fft2(const M& y) {
108110
Eigen::Matrix<scalar_type_t<M>, -1, -1> x(y.rows(), y.cols());
109111
for (int j = 0; j < x.cols(); ++j)

stan/math/rev/fun/fft.hpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,70 @@ inline plain_type_t<V> inv_fft(const V& y) {
103103
return plain_type_t<V>(res);
104104
}
105105

106+
/**
107+
* Return the two-dimensional discrete Fourier transform of the
108+
* specified complex matrix. The 2D discrete Fourier transform first
109+
* runs the discrete Fourier transform on the each row, then on each
110+
* column of the result.
111+
*
112+
* The adjoint computation is given by
113+
* ```
114+
* adjoint(x) += size(y) * inv_fft2(adjoint(y))
115+
* ```
116+
*
117+
* @tparam M type of complex matrix argument
118+
* @param[in] x matrix to transform
119+
* @return discrete 2D Fourier transform of `x`
120+
*/
121+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
122+
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
123+
inline plain_type_t<M> fft2(const M& x) {
124+
arena_t<M> arena_v = x;
125+
arena_t<M> res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
126+
127+
reverse_pass_callback([arena_v, res]() mutable {
128+
auto adj_inv_fft = inv_fft2(to_complex(res.real().adj(), res.imag().adj()));
129+
adj_inv_fft *= res.size();
130+
arena_v.real().adj() += adj_inv_fft.real();
131+
arena_v.imag().adj() += adj_inv_fft.imag();
132+
});
133+
134+
return plain_type_t<M>(res);
135+
}
136+
137+
/**
138+
* Return the two-dimensional inverse discrete Fourier transform of
139+
* the specified complex matrix. The 2D inverse discrete Fourier
140+
* transform first runs the 1D inverse Fourier transform on the
141+
* columns, and then on the resulting rows. The composition of the
142+
* FFT and inverse FFT (or vice-versa) is the identity.
143+
*
144+
* The adjoint computation is given by
145+
* ```
146+
* adjoint(y) += (1 / size(x)) * fft2(adjoint(x))
147+
* ```
148+
*
149+
* @tparam M type of complex matrix argument
150+
* @param[in] y matrix to inverse trnasform
151+
* @return inverse discrete 2D Fourier transform of `y`
152+
*/
153+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
154+
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
155+
inline plain_type_t<M> inv_fft2(const M& y) {
156+
arena_t<M> arena_v = y;
157+
arena_t<M> res
158+
= inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
159+
160+
reverse_pass_callback([arena_v, res]() mutable {
161+
auto adj_fft = fft2(to_complex(res.real().adj(), res.imag().adj()));
162+
adj_fft /= res.size();
163+
164+
arena_v.real().adj() += adj_fft.real();
165+
arena_v.imag().adj() += adj_fft.imag();
166+
});
167+
return plain_type_t<M>(res);
168+
}
169+
106170
} // namespace math
107171
} // namespace stan
108172
#endif

0 commit comments

Comments
 (0)