Skip to content

Commit 860e18a

Browse files
anjali411malfet
authored andcommitted
Update torch.set_default_dtype doc (pytorch#41263)
Summary: Pull Request resolved: pytorch#41263 Test Plan: Imported from OSS Differential Revision: D22482989 Pulled By: anjali411 fbshipit-source-id: 2aadfbb84bbab66f3111970734a37ba74d817ffd
1 parent 8f804ba commit 860e18a

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torch/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,22 +270,30 @@ def set_default_tensor_type(t):
270270

271271

272272
def set_default_dtype(d):
273-
r"""Sets the default floating point dtype to :attr:`d`. This type will be
274-
used as default floating point type for type inference in
275-
:func:`torch.tensor`.
273+
r"""Sets the default floating point dtype to :attr:`d`.
274+
This dtype is:
275+
1. The inferred dtype for python floats in :func:`torch.tensor`.
276+
2. Used to infer dtype for python complex numbers. The default complex dtype is set to
277+
``torch.complex128`` if default floating point dtype is ``torch.float64``,
278+
otherwise it's set to ``torch.complex64``
276279
277280
The default floating point dtype is initially ``torch.float32``.
278281
279282
Args:
280283
d (:class:`torch.dtype`): the floating point dtype to make the default
281284
282285
Example::
283-
284-
>>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
286+
>>> # initial default for floating point is torch.float32
287+
>>> torch.tensor([1.2, 3]).dtype
285288
torch.float32
289+
>>> # initial default for floating point is torch.complex64
290+
>>> torch.tensor([1.2, 3j]).dtype
291+
torch.complex64
286292
>>> torch.set_default_dtype(torch.float64)
287-
>>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
293+
>>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
288294
torch.float64
295+
>>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
296+
torch.complex128
289297
290298
"""
291299
_C._set_default_dtype(d)

0 commit comments

Comments
 (0)