"""
Extension points in ``nn.Module`` for ``load_state_dict`` and tensor subclasses
===============================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

This recipe introduces a new utility function ``torch.utils.swap_tensors``
as well as two new extension points where it has been integrated in
``nn.Module``:

* ``nn.Module.to()`` and related methods
* ``nn.Module.load_state_dict()``

.. note::
    This recipe requires PyTorch 2.3.0 or later.
"""

###############################################################################
# ``torch.utils.swap_tensors``
# ----------------------------
# ``torch.utils.swap_tensors`` (hereafter referred to as ``swap_tensors``) is a
# utility function that takes in two Python tensors and swaps them.

import torch
import torch.nn as nn
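
# Create two small tensors to swap (setup sketched here; the values are chosen
# purely for illustration).
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)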
print (f"After swapping, t1: { t1 } , t2: { t2 } " )
################################################################################
# More specifically, ``swap_tensors`` swaps the Python ``__class__``, ``__dict__``
# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.
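#
# As a quick illustration (a sketch, not part of the original example set), an
# attribute stored in one tensor's ``__dict__`` travels with the swap, and
# metadata such as the shape is exchanged as well:

a = torch.ones(2)
b = torch.zeros(3)
a.tag = "from_a"  # arbitrary attribute stored in a's __dict__
torch.utils.swap_tensors(a, b)
print(f"b.tag: {b.tag}")  # the attribute travelled with the swap
print(f"a.shape: {a.shape}, b.shape: {b.shape}")  # metadata was exchanged too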

################################################################################
# Application to ``nn.Module``
# ----------------------------
# This utility is pertinent to ``nn.Module`` when a Python object outside
# of the module holds a reference to parameters of the module. If an ``nn.Module``
# modifies any of its parameters out of place, the object holding references to
# the parameters will not see the change. A classic example of this is the
# optimizer, which holds a reference to the parameters of the ``nn.Module``.
# This leads to a silent correctness issue where the ``optimizer.step()`` will
# run without error but the weights of the ``nn.Module`` will not be updated.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
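print(f"weight in mod: {mod.weight}")
# Replace the weight out of place (illustrative sketch): the optimizer below
# still holds a reference to the old parameter object.
mod.weight = torch.nn.Parameter(mod.weight.detach().clone() * 2)
print(f"weight in mod: {mod.weight}")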
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")

################################################################################
# ``nn.Module.to()`` and related methods
# --------------------------------------
# This includes methods that change the device of the module (such as ``nn.Module.cpu()``),
# methods that change the ``dtype`` of the module (such as ``nn.Module.float()``)
# as well as methods that allow the module to be materialized
# (such as ``nn.Module.to_empty()``).
#
# At first glance, it might be non-intuitive that these methods are able to
# modify the parameters of the module in-place. The existing approach has been
# to use a nasty hack dating back from the first days of PyTorch.
#
# Notably, the existing approach does not work in these cases:
#
# * when using ``__torch_dispatch__`` subclasses
# * when ``param`` and ``new_param`` do not have the same Python ``type()``
# * for tensors with special C++ representations (such as sparse tensors and ``XLA`` tensors)
#
# In the following part of this recipe, we will define a toy ``__torch_dispatch__``
# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
# This subclass will be used for illustration purposes throughout the rest of
# the tutorial. For brevity, we omit most of the ``__torch_dispatch__``
# implementation.

aten = torch.ops.aten


class MyQuantizedLinearWeight(torch.Tensor):
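    # A minimal sketch of the wrapper subclass (the recipe deliberately omits
    # most of the real implementation). ``elem`` holds the underlying payload
    # and ``scale`` is a made-up quantization scale used only for illustration.
    @staticmethod
    def __new__(cls, elem, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=elem.requires_grad)

    def __init__(self, elem, scale):
        self.elem = elem
        self.scale = scale

    def __repr__(self):
        return f"MyQuantizedLinearWeight(elem={self.elem}, scale={self.scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # Handle just enough ops (detach and dtype/device copies) for the
        # examples below; anything else falls through to the error.
        if func in (aten.detach.default, aten._to_copy.default):
            new_elem = func(args[0].elem, *args[1:], **(kwargs or {}))
            return cls(new_elem, args[0].scale)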
        raise NotImplementedError(f"Unsupported function {func}")
#################################################################################
# Let us create an ``nn.Linear`` layer of ``dtype`` ``torch.float32`` where the weight is
# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
# Observe that the weight's ``dtype`` changes as expected. However, the ``dtype``
# of the subclass' payload (``elem``) does not change.

m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
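# Convert and inspect the dtypes (continuation sketched for illustration): as
# described above, the wrapper's ``dtype`` changes while the payload ``elem``
# keeps its original ``dtype``.
m.bfloat16()
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")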
print(f"m.bias.dtype: {m.bias.dtype}")

################################################################################
# To this end, we introduce a global config
# ``torch.__future__.set_swap_module_params_on_conversion`` that will use
# ``swap_tensors`` to swap the parameters of the module while preserving
# references in place of ``.data`` setting. When this config is set,
# ``swap_tensors`` will be used during the conversion, which ensures that
# the ``dtype`` of the payload is properly converted.

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5, dtype=torch.float32)
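# With the flag enabled, the same conversion is sketched again: this time the
# payload ``elem`` is converted as well.
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
m.bfloat16()
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")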
################################################################################
# ``nn.Module.load_state_dict()``
# --------------------------------
# Depending on the value of the ``assign`` keyword argument passed
# to ``load_state_dict()``, there are two ways to load the ``state_dict``:
#
# * ``assign=False``: preserves the properties of ``module.param`` and only takes the values
#   from ``state_dict['param_name']``
# * ``assign=True``: preserves the properties and values of ``state_dict['param_name']``.
#
#
# Previously, these were implemented with in-place ``copy_`` and ``__setattr__`` respectively.
# With the existing implementation, each approach had its own limitations -- ``assign=False``
# imposes the constraint that the type of the parameter in the ``state_dict`` must
# be the same as the type of the parameter in the module, while ``assign=True`` imposes
# the constraint that anything that holds references to the module's parameters must
# be initialized after ``nn.Module.load_state_dict()``.
#
# Now, we address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
# and introducing a new extension point ``torch.Tensor.module_load(self, other, assign=False)``.
# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
# we can use a ``__torch_function__`` handler for ``module_load`` to apply a
# custom transformation to the value in the ``state_dict``. The result of this
# transformation will be swapped with the parameter in the module.
#
# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
# defined above to illustrate how we can use these features to apply a
# custom quantization scheme to the weights of a linear layer when
# loading the ``state_dict``.
#
# Recall that the ``__torch_function__`` handler for ``module_load`` will be
# invoked if either ``self`` or ``other`` (in this case ``param`` or
# ``state_dict[param_key]``) are ``MyQuantizedLinearWeight`` subclasses.
#
# Assume that we expect the ``state_dict`` to contain plain tensors and the
# module to contain ``MyQuantizedLinearWeight`` parameters where we want the
# tensors in the ``state_dict`` to be transformed into the subclass. Then we
# can define a ``__torch_function__`` handler for ``torch.Tensor.module_load``
# as such:

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
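    # A sketch of one possible handler body (it assumes the subclass stores its
    # payload in ``elem`` and its scale in ``scale``): when ``module_load`` is
    # called with a ``MyQuantizedLinearWeight`` parameter and a plain tensor from
    # the ``state_dict``, wrap the incoming value in the subclass, reusing the
    # parameter's scale; defer to the default behavior for every other function.
    kwargs = {} if kwargs is None else kwargs
    if func is torch.Tensor.module_load:
        dest, src = args[0], args[1]  # dest: module param, src: state_dict value
        return MyQuantizedLinearWeight(src, dest.scale)
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
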
MyQuantizedLinearWeight.__torch_function__ = custom_torch_function
#################################################################################
# First, let us create a skeleton of a model on the meta device to avoid
# materializing storages. We convert all weights in the modules to
# ``MyQuantizedLinearWeight`` subclasses while leaving biases intact.

def fn(m):
    if isinstance(m, nn.Linear):
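        # Sketched body (illustrative): wrap the weight in the quantized
        # subclass, leaving the bias untouched.
        m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))


# Build the skeleton on the meta device so no real storage is allocated
# (a single linear layer is assumed here for illustration).
with torch.device("meta"):
    m = nn.Linear(3, 5)
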
m.apply(fn)
#################################################################################
# We can then load the ``state_dict``. Observe that we use ``assign=True`` because
# for biases, we want to preserve the properties of the tensor in the ``state_dict``
# (for example, we do not want the bias to be on the ``meta`` device after loading).

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
#################################################################################
# The above is a toy example of how we can use the new extension point in
# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios, such
# as when we have tensor subclasses in the ``state_dict`` and plain ``nn.Parameters``/
# tensors in the module, or when both are tensor subclasses. Based on the use
# case, we can define the ``__torch_function__`` handler for ``module_load``
# to apply the transforms as needed.
#
# Conclusion
# ----------
# In this recipe, we learned about ``swap_tensors``, the importance
# of preserving references for parameters in ``nn.Module``, as well as how to
# use the two new extension points that are gated by
# ``torch.__future__.set_swap_module_params_on_conversion``.