
Commit 2a44370

early all-reduce total_norm in non-PP grad norm clipping

ghstack-source-id: cf1729c
Pull Request resolved: #769

1 parent: 3f20451

1 file changed: +11 −9 lines


torchtitan/utils.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -384,16 +384,18 @@ def clip_grad_norm_(
         grads, norm_type, error_if_nonfinite, foreach
     )
 
-    if pp_mesh is not None:
-        if isinstance(total_norm, DTensor):
-            # will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor
-
-            # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
-            # we can simply reduce the DTensor to get the total norm in this tensor's process group
-            # and then convert it to a local tensor
-            total_norm = total_norm.full_tensor()
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # NOTE: It has two purposes:
+    # 1. to make sure the total norm is computed correctly when PP is used (see below)
+    # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()
 
-    # TODO: cleanup maybe using DTensor
+    if pp_mesh is not None:
         if math.isinf(norm_type):
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
         else:
```
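For context, here is a minimal sketch of the control flow in `clip_grad_norm_` after this change. Only the DTensor early reduction and the `pp_mesh` branch come directly from the diff; the signature (mirroring `torch.nn.utils.clip_grad_norm_`), the `get_total_norm` call, and the final scaling step are assumptions reconstructed from the hunk context and standard PyTorch utilities, not torchtitan's exact code.

```python
# Hedged sketch of clip_grad_norm_ after this commit. Everything outside the
# DTensor reduction and the pp_mesh branch is an assumption based on the
# hunk context and the usual torch.nn.utils clipping formula.
import math

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor


def clip_grad_norm_(parameters, max_norm, norm_type=2.0,
                    error_if_nonfinite=False, foreach=None, pp_mesh=None):
    grads = [p.grad for p in parameters if p.grad is not None]
    # With non-PP parallelisms the gradients are DTensors, so this returns a
    # DTensor whose placement is _NormPartial (a per-rank partial norm).
    total_norm = torch.nn.utils.get_total_norm(
        grads, norm_type, error_if_nonfinite, foreach
    )

    # Early reduction (the point of this commit): reduce the partial DTensor
    # norm to a plain local tensor *before* any PP communication, so that
    # (1) the PP all-reduce below combines already-correct per-stage norms and
    # (2) total_norm.item() is correct for callers even without PP.
    if isinstance(total_norm, DTensor):
        total_norm = total_norm.full_tensor()

    # Pipeline parallelism: each PP rank holds the norm of its own stage's
    # gradients, so combine the per-stage norms across the PP group.
    if pp_mesh is not None:
        if math.isinf(norm_type):
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
        else:
            total_norm = total_norm ** norm_type
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
            total_norm = total_norm ** (1.0 / norm_type)

    # Standard clipping step (assumed): scale gradients by max_norm / total_norm.
    clip_coef = min(max_norm / (float(total_norm) + 1e-6), 1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm
```

Note the ordering: because the `full_tensor()` reduction now happens before the `pp_mesh` check, the PP all-reduce always operates on plain local tensors, whether or not PP is combined with other parallelisms.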
