
Commit 8f804ba

anjali411 authored and malfet committed
Doc note for complex (pytorch#41252)
Summary: Pull Request resolved: pytorch#41252
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D22553266
Pulled By: anjali411
fbshipit-source-id: f6dc409da048496d72b29b0976dfd3dd6645bc4d
1 parent a395e09 commit 8f804ba

File tree

3 files changed: +153 -0 lines changed


docs/source/complex_numbers.rst

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
.. _complex_numbers-doc:

Complex Numbers
===============

Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where :math:`a` and :math:`b`
are real numbers, and *j* is a solution of the equation :math:`x^2 = -1`. Complex numbers frequently occur in
mathematics and engineering, especially in signal processing. Traditionally, many users and libraries
(e.g., TorchAudio) have handled complex numbers by representing the data in float tensors with shape
:math:`(..., 2)`, where the last dimension contains the real and imaginary values.
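
For illustration, a minimal sketch of that older workaround (`z` is a hypothetical name and the values are
random)::

    >>> z = torch.randn(3, 2)             # three "complex" values stored as (real, imag) pairs
    >>> real, imag = z[..., 0], z[..., 1]
    >>> real * real + imag * imag         # squared magnitude, computed by hand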

Tensors of complex dtypes provide a more natural user experience for working with complex numbers. Operations on
complex tensors (e.g., :func:`torch.mv`, :func:`torch.matmul`) are likely to be faster and more memory efficient
than operations on float tensors mimicking them. Operations involving complex numbers in PyTorch are optimized
to use vectorized assembly instructions and specialized kernels (e.g., LAPACK, cuBLAS).

.. note::
    Spectral operations (e.g., :func:`torch.fft`, :func:`torch.stft`) currently don't use complex tensors, but
    the API will soon be updated to use them.

.. warning::
    Complex tensors are a beta feature and subject to change.

Creating Complex Tensors
------------------------

We support two complex dtypes: `torch.cfloat` and `torch.cdouble`.

::

    >>> x = torch.randn(2, 2, dtype=torch.cfloat)
    >>> x
    tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
            [ 0.7706+0.1421j,  1.2110+0.1918j]])

.. note::

    The default dtype for complex tensors is determined by the default floating point dtype.
    If the default floating point dtype is `torch.float64`, then complex numbers are inferred to
    have a dtype of `torch.complex128`; otherwise they are assumed to have a dtype of `torch.complex64`.
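
For instance, a sketch of how this inference behaves (assuming the default floating point dtype starts out
as `torch.float32`)::

    >>> torch.tensor([1+2j]).dtype
    torch.complex64
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.tensor([1+2j]).dtype
    torch.complex128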

All factory functions apart from :func:`torch.linspace`, :func:`torch.logspace`, and :func:`torch.arange` are
supported for complex tensors.
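
For example, a quick sketch with two common factory functions (assuming, as for most factory functions, that
they accept a complex dtype directly)::

    >>> torch.zeros(2, dtype=torch.cfloat)
    tensor([0.+0.j, 0.+0.j])
    >>> torch.ones(2, dtype=torch.cdouble)
    tensor([1.+0.j, 1.+0.j], dtype=torch.complex128)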

Transition from the old representation
--------------------------------------

Users who previously worked around the lack of complex tensors with real tensors of shape :math:`(..., 2)`
can easily switch to using complex tensors in their code with :func:`torch.view_as_complex`
and :func:`torch.view_as_real`. Note that these functions don't perform any copy and return a
view of the input tensor.

::

    >>> x = torch.randn(3, 2)
    >>> x
    tensor([[ 0.6125, -0.1681],
            [-0.3773,  1.3487],
            [-0.0861, -0.7981]])
    >>> y = torch.view_as_complex(x)
    >>> y
    tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
    >>> torch.view_as_real(y)
    tensor([[ 0.6125, -0.1681],
            [-0.3773,  1.3487],
            [-0.0861, -0.7981]])
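
One caveat worth noting: :func:`torch.view_as_complex` expects the last dimension of its input to be of
size 2, so a sketch of the failure mode looks like this (the exact error text may differ)::

    >>> torch.view_as_complex(torch.randn(2, 3))
    RuntimeError: Tensor must have a last dimension of size 2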

Accessing real and imag
-----------------------

The real and imaginary values of a complex tensor can be accessed using the :attr:`real` and
:attr:`imag` attributes.

.. note::
    Accessing `real` and `imag` attributes doesn't allocate any memory, and in-place updates on the
    `real` and `imag` tensors will update the original complex tensor. Also, the
    returned `real` and `imag` tensors are not contiguous.

::

    >>> y.real
    tensor([ 0.6125, -0.3773, -0.0861])
    >>> y.imag
    tensor([-0.1681,  1.3487, -0.7981])

    >>> y.real.mul_(2)
    tensor([ 1.2250, -0.7546, -0.1722])
    >>> y
    tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
    >>> y.real.stride()
    (2,)

Angle and abs
-------------

The angle and absolute values of a complex tensor can be computed using :func:`torch.angle` and
:func:`torch.abs`.

::

    >>> x1 = torch.tensor([3j, 4+4j])
    >>> x1.abs()
    tensor([3.0000, 5.6569])
    >>> x1.angle()
    tensor([1.5708, 0.7854])
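
These agree with computing the magnitude and phase by hand on the real view (a small cross-check using the
same `x1` as above; `z` is a hypothetical scratch name)::

    >>> z = torch.view_as_real(x1)
    >>> z.pow(2).sum(-1).sqrt()              # same as x1.abs()
    tensor([3.0000, 5.6569])
    >>> torch.atan2(z[..., 1], z[..., 0])    # same as x1.angle()
    tensor([1.5708, 0.7854])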

Linear Algebra
--------------

Currently, linear algebra operation support for complex tensors is very minimal.
We support :func:`torch.mv`, :func:`torch.svd`, :func:`torch.qr`, and :func:`torch.inverse`
(the latter three are only supported on CPU). However, we are working to add support for more
functions soon: :func:`torch.matmul`, :func:`torch.solve`, :func:`torch.eig`, and
:func:`torch.symeig`. If any of these would help your use case, please
`search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_
to see whether an issue has already been filed and, if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.
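
As a small illustration of what already works, a complex matrix-vector product (a sketch; the inputs are
random and the output values shown are illustrative)::

    >>> A = torch.randn(2, 2, dtype=torch.cfloat)
    >>> v = torch.randn(2, dtype=torch.cfloat)
    >>> torch.mv(A, v)
    tensor([ 0.5466+1.1042j, -0.4036-0.1462j])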

Serialization
-------------

Complex tensors can be serialized, allowing data to be saved as complex values.

::

    >>> torch.save(y, 'complex_tensor.pt')
    >>> torch.load('complex_tensor.pt')
    tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])

Autograd
--------

PyTorch supports autograd for complex tensors. The autograd APIs can be
used for both holomorphic and non-holomorphic functions. For holomorphic functions,
you get the regular complex gradient. For :math:`C \rightarrow R` real-valued loss functions,
`grad.conj()` gives a descent direction. For more details, check out the note :ref:`complex_autograd-doc`.
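
A minimal sketch of a :math:`C \rightarrow R` loss (assuming the operations used here are already covered
by complex autograd in the beta)::

    >>> x = torch.randn(3, dtype=torch.cfloat, requires_grad=True)
    >>> loss = torch.view_as_real(x).pow(2).sum()   # real-valued loss: the squared magnitude of x, summed
    >>> loss.backward()
    >>> x.grad.dtype
    torch.complex64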

We do not support the following subsystems:

* Quantization

* JIT

* Sparse Tensors

* Distributed

If any of these would help your use case, please
`search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_
to see whether an issue has already been filed and, if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
 nn.init
 onnx
 optim
+complex_numbers
 quantization
 rpc
 torch.random <random>

docs/source/notes/autograd.rst

Lines changed: 2 additions & 0 deletions
@@ -211,6 +211,8 @@ Autograd relies on the user to write thread safe C++ hooks. If you want the hook
 to be correctly applied in multithreading environment, you will need to write
 proper thread locking code to ensure the hooks are thread safe.

+.. _complex_autograd-doc:
+
 Autograd for Complex Numbers
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^