Commit a395e09

anjali411 authored and malfet committed
Autograd Doc for Complex Numbers (pytorch#41012)
Summary: Pull Request resolved: pytorch#41012
Test Plan: Imported from OSS
Differential Revision: D22476911
Pulled By: anjali411
fbshipit-source-id: 7da20cb4312a0465272bebe053520d9911475828
1 parent 2ca5543 commit a395e09

File tree

1 file changed: +79 -0 lines changed


docs/source/notes/autograd.rst

Lines changed: 79 additions & 0 deletions
@@ -210,3 +210,82 @@ No thread safety on C++ hooks
Autograd relies on the user to write thread safe C++ hooks. If you want the hook
to be correctly applied in a multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd on complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
which compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is the imaginary unit.
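
For instance, for :math:`F(z) = z^2` this decomposition is :math:`u(x, y) = x^2 - y^2` and
:math:`v(x, y) = 2xy`. A minimal sketch of this concrete case (plain Python, not a PyTorch API;
the helper names are only illustrative):

.. code::

    def u(x, y):
        # real part of (x + 1j*y)**2
        return x * x - y * y

    def v(x, y):
        # imaginary part of (x + 1j*y)**2
        return 2 * x * y

    def F(z):
        x, y = z.real, z.imag
        return u(x, y) + v(x, y) * 1j

    assert F(3 + 4j) == (3 + 4j) ** 2
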
We define the :math:`JVP` for the function :math:`F` at :math:`(x, y)`, applied to a tangent
vector :math:`c+dj \in ℂ`, as:

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

where

.. math::
    J = \begin{bmatrix}
            \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y} \\
            \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y}
        \end{bmatrix}

This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the left-multiplication
by :math:`\begin{bmatrix} 1 & 1j \end{bmatrix}` is used to identify the result as a complex number.

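For the holomorphic example :math:`F(z) = z^2` above, the Jacobian of :math:`(u, v) = (x^2 - y^2, 2xy)`
is known in closed form, so the JVP formula can be checked directly. A minimal sketch (plain Python,
not a PyTorch API; the `jvp` helper is only illustrative) showing that it reproduces the ordinary
complex derivative :math:`F'(z) = 2z` applied to the tangent:

.. code::

    def jvp(z, tangent):
        x, y = z.real, z.imag
        c, d = tangent.real, tangent.imag
        # J = [[du/dx, du/dy], [dv/dx, dv/dy]] for u = x^2 - y^2, v = 2*x*y
        J = [[2 * x, -2 * y],
             [2 * y, 2 * x]]
        # [1, 1j] * J gives a complex row vector ...
        row = [J[0][0] + 1j * J[1][0], J[0][1] + 1j * J[1][1]]
        # ... which is then applied to the real tangent [c, d]^T
        return row[0] * c + row[1] * d

    z, t = 1 + 2j, 0.5 - 1j
    assert abs(jvp(z, t) - 2 * z * t) < 1e-12
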
We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in ℂ` as:

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above. Please refer to
the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
for an explanation of the negative signs in the formula.

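Continuing the :math:`F(z) = z^2` example, a sketch of the VJP formula (again plain Python, only
illustrative), showing where the negated :math:`d` and the :math:`-1j` enter:

.. code::

    def vjp(z, cotangent):
        x, y = z.real, z.imag
        c, d = cotangent.real, cotangent.imag
        J = [[2 * x, -2 * y],
             [2 * y, 2 * x]]
        # [c, -d] * J gives a real row vector ...
        row = [c * J[0][0] - d * J[1][0], c * J[0][1] - d * J[1][1]]
        # ... which is then applied to [1, -1j]^T
        return row[0] - 1j * row[1]

    z, w = 1 + 2j, 0.5 - 1j
    # for a holomorphic F, this formula evaluates to F'(z) * w = 2 * z * w
    assert abs(vjp(z, w) - 2 * z * w) < 1e-12
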
**What happens if I call backward() on a complex scalar?**
*******************************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can't hope to represent all of them within a complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
Cauchy-Riemann equations, which ensure that `2x2` Jacobians have the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using `backward`, which is just a call to `vjp` with the covector `1.0`.

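Concretely, plugging the covector `1.0` (that is, :math:`c = 1, d = 0`) into the VJP formula above and
using the Cauchy-Riemann equations :math:`\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y}` and
:math:`\frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x}` gives exactly the complex derivative
:math:`F'(z)` of the holomorphic function:

.. math::
    \begin{bmatrix} 1 & 0 \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}
    = \frac{\partial u(x, y)}{\partial x} - 1j \frac{\partial u(x, y)}{\partial y}
    = \frac{\partial u(x, y)}{\partial x} + 1j \frac{\partial v(x, y)}{\partial x}
    = F'(z)
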
The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(i.e., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.

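For example, a minimal sketch of passing an explicit covector to :func:`torch.autograd.backward`
(this assumes a build where complex autograd is supported for the ops involved; the values are arbitrary):

.. code::

    import torch

    z = torch.tensor(1.0 + 2.0j, requires_grad=True)
    loss = z * z  # a complex scalar "loss"

    # Instead of the implicit covector 1.0, pass an explicit complex covector.
    torch.autograd.backward(loss, grad_tensors=torch.tensor(0.5 - 1.0j))
    print(z.grad)
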
**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an
identity), we use the formulas given below for cross-domain functions.

The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℝ^2 → ℂ` are defined as:

.. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
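
As a sanity check of these conventions, consider :math:`f1: ℂ → ℝ^2`, :math:`f1(z) = (x, y)` composed with
:math:`f2: ℝ^2 → ℂ`, :math:`f2(x, y) = x + 1j y`; both have identity Jacobians, and chaining their JVPs
returns the original tangent, so :math:`ℂ → ℝ^2 → ℂ` behaves as an identity. A minimal sketch (plain
Python, only illustrative):

.. code::

    def chained_jvp(tangent):
        c, d = tangent.real, tangent.imag
        # JVP of f1: J1 * [c, d]^T with J1 = I  ->  the real pair (c, d)
        c1, d1 = c, d
        # JVP of f2: [1, 1j] * J2 * [c1, d1]^T with J2 = I  ->  c1 + 1j*d1
        return c1 + 1j * d1

    t = 0.5 - 1.0j
    assert chained_jvp(t) == t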
