@@ -384,16 +384,18 @@ def clip_grad_norm_(
         grads, norm_type, error_if_nonfinite, foreach
     )
 
-    if pp_mesh is not None:
-        if isinstance(total_norm, DTensor):
-            # will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor
-
-            # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
-            # we can simply reduce the DTensor to get the total norm in this tensor's process group
-            # and then convert it to a local tensor
-            total_norm = total_norm.full_tensor()
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # NOTE: It has two purposes:
+    # 1. to make sure the total norm is computed correctly when PP is used (see below)
+    # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()
 
-        # TODO: cleanup maybe using DTensor
+    if pp_mesh is not None:
         if math.isinf(norm_type):
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
         else:
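For context on the pipeline-parallel reduction continued after the `else:` above: a p-norm over all gradients can be recombined from per-stage partial norms by summing the p-th powers and taking the p-th root (or taking the max for the inf-norm), which is what the `dist.all_reduce` over `pp_mesh.get_group()` accomplishes. The sketch below only illustrates that identity with plain tensors on a single process; the helper name and the simulated two-stage split are assumptions for illustration, not code from this PR.

```python
import math
import torch


def combine_stage_norms(stage_norms: torch.Tensor, norm_type: float) -> torch.Tensor:
    # stage_norms: one local grad norm per pipeline stage (simulated here on one process;
    # in the real code the combination happens collectively via dist.all_reduce over
    # pp_mesh.get_group()).
    if math.isinf(norm_type):
        # inf-norm: the global norm is the max of the per-stage norms (ReduceOp.MAX).
        return stage_norms.max()
    # p-norm: sum of p-th powers across stages, then the p-th root.
    return (stage_norms ** norm_type).sum() ** (1.0 / norm_type)


# Simulate two pipeline stages holding disjoint gradient shards (hypothetical split).
grads = [torch.randn(8), torch.randn(8)]
norm_type = 2.0

per_stage = torch.stack([g.norm(norm_type) for g in grads])
combined = combine_stage_norms(per_stage, norm_type)
reference = torch.cat(grads).norm(norm_type)  # norm over all gradients at once

assert torch.allclose(combined, reference)
```

With actual pipeline parallelism, `per_stage` would instead be each rank's local `total_norm`, and the sum or max would be performed collectively with `dist.all_reduce`, as in the diff.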