1+ """
2+ Understanding requires_grad, retain_grad, Leaf, and Non-leaf tensors
3+ ====================================================================
4+
5+ **Author:** `Justin Silver <https://github.com/j-silv>`__
6+
7+ This tutorial explains the subtleties of ``requires_grad``,
8+ ``retain_grad``, leaf, and non-leaf tensors using a simple example.
9+
10+ Before starting, make sure you understand `tensors and how to manipulate
11+ them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
12+ A basic knowledge of `how autograd
13+ works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
14+ would also be useful.
15+
16+ """


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt



######################################################################
# Next, we instantiate a simple network so that we can focus on the
# gradients. It consists of an affine layer, followed by a ReLU
# activation, and ending with an MSE loss between the prediction and
# label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We'll discuss this more in a later
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                       # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)   # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)   # bias with shape: (1, 2)
y = torch.ones(1, 2)                       # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss


######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor, starting from
# the graph roots (outputs) and working back to the leaves (inputs), using
# the `chain rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from
# calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf = }")

# prints False because this tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf = }")

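
######################################################################
# The recorded operations can also be inspected directly: every non-leaf
# tensor carries a ``grad_fn`` node corresponding to the operation that
# produced it, and each ``grad_fn`` links back to its inputs through
# ``next_functions``. The following is just an illustrative peek at how
# the graph above is wired together, not something a typical training
# loop needs:
#

print(f"{z.grad_fn = }")        # AddBackward0, since z = (x @ W) + b
print(f"{y_pred.grad_fn = }")   # ReluBackward0
print(f"{loss.grad_fn.next_functions = }")  # links back toward the leaves
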

######################################################################
# The distinction between leaf and non-leaf determines whether the
# tensor's gradient will be stored in the ``grad`` attribute after the
# backward pass, and thus be usable for `gradient
# descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We'll cover
# this in more detail in the `following section <#retain-grad>`__.
#
# Let's now investigate how PyTorch calculates and stores gradients for
# the tensors in its computational graph.
#


######################################################################
# ``requires_grad``
# -----------------
#
# To build the computational graph which can be used for gradient
# calculation, we need to pass ``requires_grad=True`` to the tensor
# constructor. By default, this value is ``False``, and so PyTorch does
# not track gradients on any created tensors. To verify this, try not
# setting ``requires_grad``, re-run the forward pass, and then run
# backpropagation. You will see:
#
# ::
#
#    >>> loss.backward()
#    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
#
# This error means that autograd can't backpropagate to any leaf tensors
# because ``loss`` is not tracking gradients. If you need to change the
# property, you can call ``requires_grad_()`` on the tensor (notice the
# ``_`` suffix).
#
# We can sanity check which nodes require gradient calculation, just like
# we did above with the ``is_leaf`` attribute:
#

print(f"{x.requires_grad = }")  # prints False because requires_grad=False by default
print(f"{W.requires_grad = }")  # prints True because we set requires_grad=True in the constructor
print(f"{z.requires_grad = }")  # prints True because tensor is a non-leaf node

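
######################################################################
# As a small illustrative aside (on a throwaway copy of ``x``, so that the
# network above is unaffected), ``requires_grad_()`` is how you would flip
# the flag in place after a tensor has already been constructed:
#

x_copy = x.clone()                   # fresh leaf tensor with requires_grad=False
x_copy.requires_grad_(True)          # in-place toggle (note the trailing underscore)
print(f"{x_copy.requires_grad = }")  # prints True
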

######################################################################
# It's useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Another
# way to phrase this is that if at least one of the inputs to a tensor
# requires the gradient, then it will require the gradient as well.
#
# There are two exceptions to this rule (both are demonstrated in the
# short sketch that follows):
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#

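
######################################################################
# Here is a minimal sketch of those two exceptions, using a hypothetical
# ``nn.Linear`` layer that is not part of the tutorial's network:
#

tiny = nn.Linear(3, 2)                    # its nn.Parameter weights require gradients by default
print(f"{tiny.weight.requires_grad = }")  # prints True, without us setting it explicitly

with torch.no_grad():                     # locally disable gradient tracking
    out = tiny(x)
print(f"{out.requires_grad = }")          # prints False: no graph was recorded inside no_grad()
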

######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
# the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# which had ``requires_grad=True``. The ``grad`` is the gradient of the
# loss with respect to the tensor we are probing. Before running
# ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad = }")
print(f"{b.grad = }")

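
######################################################################
# These leaf gradients are exactly what an optimizer consumes. As an
# illustrative sketch only (computed out of place with a hypothetical
# learning rate of 0.1, so that ``W`` and ``b`` above are left untouched
# for the rest of the tutorial), a vanilla SGD update would look like:
#

with torch.no_grad():
    W_updated = W - 0.1 * W.grad  # descend along the negative gradient
    b_updated = b - 0.1 * b.grad
print(f"{W_updated = }")
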

######################################################################
# You might be wondering about the other tensors in our network. Let's
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad = }")
print(f"{y.grad = }")


######################################################################
# The gradients for these tensors haven't been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let's now look at an intermediate non-leaf node:
#

print(f"{z.grad = }")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node's ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don't need to access their values afterwards. To change this
# behavior, we can call the ``retain_grad()`` method on a tensor. This
# tells the autograd engine to populate that tensor's ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# we have to zero out the gradients, otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad = }")
print(f"{b.grad = }")
print(f"{z.grad = }")
print(f"{y_pred.grad = }")
print(f"{loss.grad = }")

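
######################################################################
# As an illustrative sanity check of the chain rule described earlier, we
# can recompute the retained gradient of the loss with respect to
# ``y_pred`` by hand. For an MSE loss averaged over :math:`N` elements,
# :math:`\partial L / \partial \mathbf{y}_{\text{pred}} =
# 2(\mathbf{y}_{\text{pred}} - \mathbf{y}) / N`:
#

with torch.no_grad():
    manual_grad = 2 * (y_pred - y) / y_pred.numel()
print(f"{torch.allclose(manual_grad, y_pred.grad) = }")  # prints True
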

######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node, it results in a no-op,
# since its ``grad`` is already populated after ``backward()``. If we call
# ``retain_grad()`` on a node that has ``requires_grad=False``, PyTorch
# actually throws an error, since it can't store the gradient if it is
# never calculated.
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#



######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table summarizes the above discussion; these scenarios
# are the only ones that are valid for PyTorch tensors.
#
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``is_leaf``    | ``requires_grad``      | ``retains_grad``       | ``requires_grad_()``                              | ``retain_grad()``                   |
# +================+========================+========================+===================================================+=====================================+
# | ``True``       | ``False``              | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | raises a ``RuntimeError``           |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``True``       | ``True``               | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``False``              | no-op                                             | sets ``retains_grad`` to ``True``   |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``True``               | no-op                                             | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
#

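
######################################################################
# As an illustrative check against the table, we can print these three
# attributes for one tensor from each scenario. The fresh ``x @ W``
# product below simply stands in for a non-leaf tensor on which
# ``retain_grad()`` has not been called:
#

for name, t in [("x", x), ("W", W), ("x @ W", x @ W), ("z", z)]:
    print(f"{name}: is_leaf={t.is_leaf}, "
          f"requires_grad={t.requires_grad}, retains_grad={t.retains_grad}")
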

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd's computational graph.
#
# If you would like to learn more about how PyTorch's autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#

######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#