@@ -210,3 +210,82 @@ No thread safety on C++ hooks
Autograd relies on the user to write thread-safe C++ hooks. If you want the hook
to be correctly applied in a multithreaded environment, you will need to write
proper thread-locking code to ensure that the hooks are thread safe.
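
For example, a minimal Python sketch of the locking pattern (the same idea applies to C++
hooks; the ``logging_hook`` helper and the shared ``grad_norms`` list are illustrative
assumptions, not an existing API):

.. code::

    import threading

    import torch

    stats_lock = threading.Lock()
    grad_norms = []

    def logging_hook(grad):
        # Guard the shared list so that concurrent backward passes
        # (e.g. launched from several threads) do not race on it.
        with stats_lock:
            grad_norms.append(grad.norm().item())
        return grad

    x = torch.randn(3, requires_grad=True)
    x.register_hook(logging_hook)
    (x * 2).sum().backward()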
+
+ Autograd for Complex Numbers
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ **What notion of complex derivative does PyTorch use?**
+ ********************************************************
+
+ PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
+ convention for autograd for complex numbers.
+
+ Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
+ that compute the real and imaginary parts of the function:
+
+ .. code::
+
+     def F(z):
+         x, y = z.real, z.imag
+         return u(x, y) + v(x, y) * 1j
+
+ where :math:`1j` is a unit imaginary number.
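+
+ For instance, a concrete (purely illustrative) choice is :math:`F(z) = z^2`, whose real and
+ imaginary parts are :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`:
+
+ .. code::
+
+     import torch
+
+     def u(x, y):
+         return x * x - y * y
+
+     def v(x, y):
+         return 2 * x * y
+
+     def F(z):
+         x, y = z.real, z.imag
+         return u(x, y) + v(x, y) * 1j
+
+     z = torch.tensor(1.0 + 2.0j)
+     print(F(z), z * z)  # both print tensor(-3.+4.j)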
+
+ We define the :math:`JVP` for the function :math:`F` at :math:`(x, y)` applied to a tangent
+ vector :math:`c + dj \in ℂ` as:
+
+ .. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}
+
+ where
+
+ .. math::
+     J = \begin{bmatrix}
+         \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
+         \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y}
+         \end{bmatrix}
+
+ This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the multiplication
+ with :math:`[1, 1j]^T` is used to identify the result as a complex number.
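+
+ As a sketch, the JVP can be assembled by hand from the four real partial derivatives (continuing
+ the illustrative :math:`F(z) = z^2` example; the tangent values ``c`` and ``d`` are arbitrary):
+
+ .. code::
+
+     import torch
+
+     x = torch.tensor(1.0, requires_grad=True)
+     y = torch.tensor(2.0, requires_grad=True)
+     u = x * x - y * y  # real part of z**2 at z = x + yj
+     v = 2 * x * y      # imaginary part of z**2
+
+     # The two rows of J: gradients of u and of v with respect to (x, y).
+     du_dx, du_dy = torch.autograd.grad(u, (x, y))
+     dv_dx, dv_dy = torch.autograd.grad(v, (x, y))
+
+     c, d = 0.5, -1.0  # tangent vector c + dj
+     jvp = (du_dx * c + du_dy * d) + (dv_dx * c + dv_dy * d) * 1j
+     # For this holomorphic F, jvp equals F'(z) * (c + dj) = 2z * (c + dj).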
+
+ We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c + dj \in ℂ` as:
+
+ .. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}
+
+ In PyTorch, the :math:`VJP` is mostly what we care about, as it is the computation performed when we do
+ backward mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula
+ above. Please refer to the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
+ for an explanation of the negative signs in the formula.
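+
+ A matching sketch for the VJP, using the same illustrative :math:`F(z) = z^2`:
+
+ .. code::
+
+     import torch
+
+     x = torch.tensor(1.0, requires_grad=True)
+     y = torch.tensor(2.0, requires_grad=True)
+     u, v = x * x - y * y, 2 * x * y  # parts of z**2, as before
+     du_dx, du_dy = torch.autograd.grad(u, (x, y))
+     dv_dx, dv_dy = torch.autograd.grad(v, (x, y))
+
+     # Apply [c, -d] * J * [1, -1j]^T for the cotangent vector c + dj.
+     c, d = 0.5, -1.0
+     row_x = c * du_dx - d * dv_dx  # first entry of the row vector [c, -d] * J
+     row_y = c * du_dy - d * dv_dy  # second entry
+     vjp = row_x - row_y * 1j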
+
+ **What happens if I call backward() on a complex scalar?**
+ ***********************************************************
+
+ The gradient for a complex function is computed assuming the input function is a holomorphic function.
+ This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
+ (as in the :math:`2 \times 2` Jacobian matrix above), so we can't hope to represent all of them within
+ a single complex number. However, for holomorphic functions, the gradient can be fully represented with
+ complex numbers due to the Cauchy-Riemann equations, which ensure that the :math:`2 \times 2` Jacobian has
+ the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex
+ number under multiplication. And so, we can obtain that gradient using backward, which is just a call to
+ :math:`VJP` with covector ``1.0``.
+
+ The net effect of this assumption is that the partial derivatives of the imaginary part of the function
+ (:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
+ (e.g., this is equivalent to dropping the imaginary part of the loss before performing the backward pass).
+
+ For any other desired behavior, you can specify the covector ``grad_output`` in the :func:`torch.autograd.backward` call accordingly.
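+
+ A minimal sketch of both options (the loss and the covector value here are illustrative):
+
+ .. code::
+
+     import torch
+
+     z = torch.tensor(1.0 + 2.0j, requires_grad=True)
+     loss = z * z  # a complex scalar
+
+     # Equivalent to the default behavior described above: drop the
+     # imaginary part of the loss before differentiating.
+     loss.real.backward()
+
+     # Alternatively, choose the covector grad_output explicitly, e.g.:
+     # loss.backward(torch.tensor(0.5 - 1.0j))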
+
+ **How are the JVP and VJP defined for cross-domain functions?**
+ ***************************************************************
+
+ Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be
+ an identity), we use the formulas given below for cross-domain functions.
+
+ The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℂ → ℝ^2` are defined as:
+
+ .. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}
+
+ .. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}
+
+ The :math:`JVP` and :math:`VJP` for a function :math:`f2: ℝ^2 → ℂ` are defined as:
+
+ .. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}
+
+ .. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
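+
+ A quick consistency sketch of the round trip the convention is designed around (the helper names
+ ``c_to_r2`` and ``r2_to_c`` are illustrative):
+
+ .. code::
+
+     import torch
+
+     def c_to_r2(z):  # f1: ℂ → ℝ^2
+         return z.real, z.imag
+
+     def r2_to_c(x, y):  # f2: ℝ^2 → ℂ
+         return x + y * 1j
+
+     z = torch.tensor(1.0 + 2.0j)
+     assert r2_to_c(*c_to_r2(z)) == z  # ℂ → ℝ^2 → ℂ is the identity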