"""
Understanding requires_grad, retain_grad, Leaf, and Non-leaf tensors
====================================================================

**Author:** `Justin Silver <https://github.com/j-silv>`__

This tutorial explains the subtleties of ``requires_grad``,
``retain_grad``, leaf, and non-leaf tensors using a simple example.

Before starting, make sure you understand `tensors and how to manipulate
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
A basic knowledge of `how autograd
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
would also be useful.

"""


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


######################################################################
# Next, we instantiate a simple network to focus on the gradients. This
# will be an affine layer, followed by a ReLU activation, and ending with
# an MSE loss between prediction and label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We’ll discuss more about this in a future
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss

######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor starting from the
# graph roots (outputs) to the leaves (inputs) using the `chain
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf=}")

# prints False because tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf=}")

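######################################################################
# As a quick illustrative check (the loop below is an aside and is not
# used in the rest of the tutorial), we can probe ``is_leaf`` for every
# tensor in our example and confirm the classification described above:
#

# leaf: x, W, b, y -- non-leaf: z, y_pred, loss
for name, tensor in [("x", x), ("W", W), ("b", b), ("y", y),
                     ("z", z), ("y_pred", y_pred), ("loss", loss)]:
    print(f"{name}.is_leaf = {tensor.is_leaf}")
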
######################################################################
# The distinction between leaf and non-leaf determines whether the
# tensor’s gradient will be stored in the ``grad`` property after the
# backward pass, and thus be usable for `gradient
# descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We’ll cover
# this in more detail in the `following section <#retain-grad>`__.
#
# Let’s now investigate how PyTorch calculates and stores gradients for
# the tensors in its computational graph.
#

######################################################################
# ``requires_grad``
# -----------------
#
# To build the computational graph which can be used for gradient
# calculation, we need to pass in the ``requires_grad=True`` parameter to
# a tensor constructor. By default, the value is ``False``, and thus
# PyTorch does not track gradients on any created tensors. To verify this,
# try not setting ``requires_grad``, re-run the forward pass, and then run
# backpropagation. You will see:
#
# ::
#
#    >>> loss.backward()
#    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
#
# This error means that autograd can’t backpropagate to any leaf tensors
# because ``loss`` is not tracking gradients. If you need to change the
# property, you can call ``requires_grad_()`` on the tensor (notice the \_
# suffix).
#
# We can sanity check which nodes require gradient calculation, just like
# we did above with the ``is_leaf`` attribute:
#

print(f"{x.requires_grad=}")  # prints False because requires_grad=False by default
print(f"{W.requires_grad=}")  # prints True because we set requires_grad=True in constructor
print(f"{z.requires_grad=}")  # prints True because tensor is a non-leaf node

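######################################################################
# As mentioned above, ``requires_grad_()`` (note the trailing underscore)
# changes this flag in place on an existing tensor. A minimal sketch --
# we flip ``x`` and immediately restore it so that the rest of the
# tutorial behaves exactly as described:
#

x.requires_grad_(True)
print(f"{x.requires_grad=}")  # prints True now
x.requires_grad_(False)       # restore the default for the remaining sections
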
######################################################################
# It’s useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Another
# way to phrase this is that if at least one of the inputs to a tensor
# requires the gradient, then it will require the gradient as well.
#
# There are two exceptions to this rule (both are sketched in code below):
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#

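######################################################################
# Here is a minimal sketch of the two exceptions listed above. The
# ``nn.Linear`` layer and the ``z_no_grad`` tensor are illustrative only
# and are not used in the rest of the tutorial:
#

# 1. parameters registered by an nn.Module track gradients automatically
layer = nn.Linear(3, 2)
print(all(p.requires_grad for p in layer.parameters()))  # prints True

# 2. results computed under torch.no_grad() do not track gradients,
#    even though W and b have requires_grad=True
with torch.no_grad():
    z_no_grad = (x @ W) + b
print(f"{z_no_grad.requires_grad=}")  # prints False
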
######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
# the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# which had ``requires_grad=True``. The ``grad`` is the gradient of the
# loss with respect to the tensor we are probing. Before running
# ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad=}")
print(f"{b.grad=}")

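######################################################################
# As an illustrative aside (the names below are not used elsewhere), we
# can reproduce these values by hand using the chain rule from the
# earlier section and the closed-form gradients of MSE and ReLU for this
# specific example:
#

with torch.no_grad():
    grad_y_pred = 2.0 * (y_pred - y) / y.numel()  # dL/dy_pred for mean-reduced MSE
    grad_z = grad_y_pred * (z > 0).float()        # dL/dz (ReLU passes gradient where z > 0)
    grad_W_manual = x.T @ grad_z                  # dL/dW
    grad_b_manual = grad_z                        # dL/db

print(torch.allclose(grad_W_manual, W.grad))  # prints True
print(torch.allclose(grad_b_manual, b.grad))  # prints True
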
######################################################################
# You might be wondering about the other tensors in our network. Let’s
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad=}")
print(f"{y.grad=}")


######################################################################
# The gradients for these tensors haven’t been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let’s now look at an intermediate non-leaf node:
#

print(f"{z.grad=}")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don’t need to access the values afterwards. To change this
# behavior, we can use the ``retain_grad()`` function on a tensor. This
# tells the autograd engine to populate that tensor’s ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# have to zero out gradients otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")


######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node, it results in a no-op.
# If we call ``retain_grad()`` on a node that has ``requires_grad=False``,
# PyTorch actually throws an error, since it can’t store the gradient if
# it is never calculated.
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#

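######################################################################
# We can also confirm the ``retains_grad`` values directly (an
# illustrative check): leaf nodes report ``False`` by convention, while
# the non-leaf nodes we called ``retain_grad()`` on report ``True``.
#

print(f"{W.retains_grad=}")  # prints False because W is a leaf node
print(f"{z.retains_grad=}")  # prints True because we called z.retain_grad()
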
######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table can be used as a reference which summarizes the
# above discussions. The following scenarios are the only ones that are
# valid for PyTorch tensors.
#
# +-------------+-------------------+------------------+-------------------------------------------------+-----------------------------------+
# | ``is_leaf`` | ``requires_grad`` | ``retains_grad`` | ``requires_grad_()``                            | ``retain_grad()``                 |
# +=============+===================+==================+=================================================+===================================+
# | ``True``    | ``False``         | ``False``        | sets ``requires_grad`` to ``True`` or ``False`` | no-op                             |
# +-------------+-------------------+------------------+-------------------------------------------------+-----------------------------------+
# | ``True``    | ``True``          | ``False``        | sets ``requires_grad`` to ``True`` or ``False`` | no-op                             |
# +-------------+-------------------+------------------+-------------------------------------------------+-----------------------------------+
# | ``False``   | ``True``          | ``False``        | no-op                                           | sets ``retains_grad`` to ``True`` |
# +-------------+-------------------+------------------+-------------------------------------------------+-----------------------------------+
# | ``False``   | ``True``          | ``True``         | no-op                                           | no-op                             |
# +-------------+-------------------+------------------+-------------------------------------------------+-----------------------------------+
#

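######################################################################
# As a final illustrative step (the learning rate below is arbitrary),
# the gradients stored in ``grad`` are what an optimizer consumes. A
# minimal sketch of a single SGD update on our parameters:
#

optimizer = optim.SGD([W, b], lr=0.1)
optimizer.step()       # W <- W - 0.1 * W.grad, and likewise for b
optimizer.zero_grad()  # reset grads so they don't accumulate into a later backward pass
print(f"{W=}")
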
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd’s computational graph.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#