@@ -135,98 +135,121 @@ def step(self, closure: Any = None) -> None:
         super().step(closure)
         self._step_num += 1

-    @torch.no_grad()
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
-        max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
+
+        # converts self._norm_type to a float if it's a string. Used in the case where self._norm_type is 'inf'.
+        norm_type_float = float(self._norm_type)
         all_grads = []
         total_grad_norm = None

+        sharded_params = self._sharded_params
+        replicate_params = self._replicate_params
+
         # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
-            sharded_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in dist_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            if len(sharded_grads) == 0:
-                continue
+        for dist_params in sharded_params.values():
+            sharded_grads = _get_grads(dist_params)
             all_grads.extend(sharded_grads)

-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
-
         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in self._replicate_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
+        if replicate_params:
+            replicate_grads = _get_grads(replicate_params)
+            all_grads.extend(replicate_grads)

-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        total_grad_norm = _compute_total_norm(
+            replicate_params, sharded_params, norm_type_float, self._max_gradient
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm


+def _get_grads(
+    param_list: List[torch.Tensor],
+) -> List[torch.Tensor]:
+    """Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
+    grads = [
+        p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+        for p in param_list
+        if p.grad is not None and p.grad.numel() > 0
+    ]
+    return grads
+
+
+def _compute_total_norm(
+    replicate_params: Optional[List[torch.Tensor]] = None,
+    sharded_params: Optional[Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]] = None,
+    norm_type: float = 2.0,  # can be a normal float, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both replicate params and sharded params, compute the total norm of the gradients of the full replicate params and the
+    full sharded params (parameters with a process group).
+
+    Args:
+        replicate_params (List[torch.Tensor]): list of replicate params
+        sharded_params (Optional[Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]]): dict that maps each process group to a list of sharded params
+        norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
+        max_grad_norm (float): max gradient norm.
+    """
+
+    ## compute |W|^p corresponding to all replicate params W
+
+    if replicate_params is None:
+        replicate_params = []
+    if sharded_params is None:
+        sharded_params = defaultdict(list)
+
+    def get_grad_norm_power(
+        param_list: List[torch.Tensor],
+        norm_type: float,
+        max_grad_norm: float,
+        pgs: Optional[Tuple[dist.ProcessGroup]] = None,
+    ) -> torch.Tensor:
+        """
+        Given a list of parameters, convert them to local tensors if they are DTensors,
+        and compute the squared (or p-th power) norm of the gradients of the parameters.
+        """
+        grad_list = _get_grads(param_list)
+        return _batch_cal_norm(grad_list, max_grad_norm, norm_type, pgs)
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    if sharded_params:
+        combine_sharded_norm_operator = (
+            torch.maximum if norm_type == torch.inf else torch.add
+        )
+
+        # We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max).
+        # This is specifically for the case where sharded_grad_norm is 0 and replicate_grad_norm is not,
+        # because by default torch.tensor(0.0) is on cpu, and replicate_grad_norm is on GPU. For MTIA
+        # specifically, adding a tensor on cpu and a tensor on GPU will result in an error.
+        for pgs, dist_params in sharded_params.items():
+            shard_norm = get_grad_norm_power(dist_params, norm_type, max_grad_norm, pgs)
+            sharded_grad_norm = combine_sharded_norm_operator(
+                sharded_grad_norm.to(shard_norm.device), shard_norm
+            )
+
+    # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
+    replicate_grad_norm: torch.Tensor = (
+        get_grad_norm_power(replicate_params, norm_type, max_grad_norm)
+        if replicate_params
+        else torch.tensor(0.0)
+    ).to(sharded_grad_norm.device)
+
+    combine_norm_operator = (
+        torch.maximum
+        if norm_type == torch.inf
+        else lambda a, b: torch.add(a, b).pow(1.0 / norm_type)
+    )
+
+    total_grad_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
+    return total_grad_norm
+
+
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,
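Similarly, a hedged sketch of the clipping step that follows the norm computation in clip_grad_norm_ (again standalone, with made-up names): the clip coefficient is the max norm divided by the total norm plus a small epsilon, clamped to 1.0 so gradients are only ever scaled down, and applied with one fused in-place multiply.

import torch

def clip_grads_(grads, total_norm, max_norm, eps=1e-6):
    """Scale grads in place so their combined norm does not exceed max_norm."""
    clip_coef = max_norm / (total_norm + eps)
    # Clamp at 1.0: gradients already under the limit are left untouched.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(grads, clip_coef_clamped)

grads = [torch.tensor([3.0, 4.0]), torch.tensor([6.0, 8.0])]
total_norm = torch.tensor(125.0).sqrt()  # combined L2 norm ~= 11.18
clip_grads_(grads, total_norm, max_norm=1.0)
print(grads)  # each gradient scaled by roughly 1 / 11.18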