@@ -103,6 +103,70 @@ inline plain_type_t<V> inv_fft(const V& y) {
103
103
return plain_type_t <V>(res);
104
104
}
105
105
106
+ /* *
107
+ * Return the two-dimensional discrete Fourier transform of the
108
+ * specified complex matrix. The 2D discrete Fourier transform first
109
+ * runs the discrete Fourier transform on the each row, then on each
110
+ * column of the result.
111
+ *
112
+ * The adjoint computation is given by
113
+ * ```
114
+ * adjoint(x) += size(y) * inv_fft2(adjoint(y))
115
+ * ```
116
+ *
117
+ * @tparam M type of complex matrix argument
118
+ * @param[in] x matrix to transform
119
+ * @return discrete 2D Fourier transform of `x`
120
+ */
121
+ template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr ,
122
+ require_var_t <base_type_t <value_type_t <M>>>* = nullptr >
123
+ inline plain_type_t <M> fft2 (const M& x) {
124
+ arena_t <M> arena_v = x;
125
+ arena_t <M> res = fft2 (to_complex (arena_v.real ().val (), arena_v.imag ().val ()));
126
+
127
+ reverse_pass_callback ([arena_v, res]() mutable {
128
+ auto adj_inv_fft = inv_fft2 (to_complex (res.real ().adj (), res.imag ().adj ()));
129
+ adj_inv_fft *= res.size ();
130
+ arena_v.real ().adj () += adj_inv_fft.real ();
131
+ arena_v.imag ().adj () += adj_inv_fft.imag ();
132
+ });
133
+
134
+ return plain_type_t <M>(res);
135
+ }
136
+
137
+ /* *
138
+ * Return the two-dimensional inverse discrete Fourier transform of
139
+ * the specified complex matrix. The 2D inverse discrete Fourier
140
+ * transform first runs the 1D inverse Fourier transform on the
141
+ * columns, and then on the resulting rows. The composition of the
142
+ * FFT and inverse FFT (or vice-versa) is the identity.
143
+ *
144
+ * The adjoint computation is given by
145
+ * ```
146
+ * adjoint(y) += (1 / size(x)) * fft2(adjoint(x))
147
+ * ```
148
+ *
149
+ * @tparam M type of complex matrix argument
150
+ * @param[in] y matrix to inverse trnasform
151
+ * @return inverse discrete 2D Fourier transform of `y`
152
+ */
153
+ template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr ,
154
+ require_var_t <base_type_t <value_type_t <M>>>* = nullptr >
155
+ inline plain_type_t <M> inv_fft2 (const M& y) {
156
+ arena_t <M> arena_v = y;
157
+ arena_t <M> res
158
+ = inv_fft2 (to_complex (arena_v.real ().val (), arena_v.imag ().val ()));
159
+
160
+ reverse_pass_callback ([arena_v, res]() mutable {
161
+ auto adj_fft = fft2 (to_complex (res.real ().adj (), res.imag ().adj ()));
162
+ adj_fft /= res.size ();
163
+
164
+ arena_v.real ().adj () += adj_fft.real ();
165
+ arena_v.imag ().adj () += adj_fft.imag ();
166
+ });
167
+ return plain_type_t <M>(res);
168
+ }
169
+
106
170
} // namespace math
107
171
} // namespace stan
108
172
#endif
0 commit comments