3 changes: 2 additions & 1 deletion torchtitan/components/loss.py
@@ -23,7 +23,8 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor
     )
 
 
-def build_cross_entropy_loss(job_config: JobConfig):
+def build_cross_entropy_loss(job_config: JobConfig, **kwargs):
+    del kwargs  # delete any unused arguments
     loss_fn = cross_entropy_loss
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
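The new **kwargs exists because the trainer now calls every loss builder with the same keyword arguments (see the torchtitan/train.py hunk at the bottom of this diff), so builders that only need job_config have to accept and discard the extras. A minimal standalone sketch of that convention, with placeholder builder names that are not torchtitan code:

# Toy builders illustrating the shared calling convention (names are placeholders).
def build_simple_loss(job_config, **kwargs):
    del kwargs  # parallel_dims / ft_manager are not needed by this builder
    return lambda pred, labels: ((pred - labels) ** 2).mean()


def build_mesh_aware_loss(job_config, parallel_dims, ft_manager, **kwargs):
    del kwargs  # this builder uses parallel_dims / ft_manager directly instead
    return lambda pred, labels: ((pred - labels) ** 2).sum()


# The call site can treat both builders identically:
#   loss_fn = build_fn(job_config, parallel_dims=pd, ft_manager=ft)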
1 change: 0 additions & 1 deletion torchtitan/experiments/vlm/__init__.py
@@ -6,7 +6,6 @@
 
 from dataclasses import asdict, replace
 
-from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
113 changes: 113 additions & 0 deletions torchtitan/experiments/vlm/infra/loss.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.components.ft.manager import FTManager
from torchtitan.config.job_config import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger


IGNORE_INDEX = -100  # PyTorch's default ignore_index for F.cross_entropy


# WARNING: currently this does not take into account gradient accumulation,
# so the gradient can still be biased toward the grad accum steps with fewer valid tokens.
# See: https://github.com/pytorch/torchtitan/issues/1842
def token_imbalance_ce_loss(
pred: torch.Tensor,
labels: torch.Tensor,
token_mesh: DeviceMesh,
ft_pg: dist.ProcessGroup | None,
) -> torch.Tensor:
"""
Cross‑entropy loss that is *robust* to varying numbers of valid tokens across ranks.

    In a typical distributed training setup (data parallel + sequence parallel),
    each rank computes the loss over **only its local tokens** and returns an
    *average* over those tokens (the local sum of losses divided by the local
    token count).

    Afterwards, when Fully-Sharded Data Parallel (FSDP) averages the gradients
    across all ranks, the resulting update is equivalent to a **global sample
    average** *only if every rank contains the same number of tokens*.
    In practice that assumption is violated for many workloads:
    - Sequences are padded to a fixed length -> some ranks see fewer real tokens.
    - SFT fine-tuning, where the user-query tokens are masked out.
    - Vision encoders often inject a large number of “ignored” context tokens
      that are not trained with the text-token loss.

    This function fixes the issue by dividing the **sum of the per-token losses**
    by the *average* number of non-ignored tokens per rank, computed via an
    all-reduce over ``token_mesh``. The returned scalar therefore represents the
    loss that would be obtained if every token in the entire distributed batch
    contributed with equal weight to the global gradient, regardless of how many
    padded or ignored tokens each rank contains.

    Parameters
    ----------
    pred : torch.Tensor
        Logits of shape ``[batch, seq_len, vocab]``; flattened internally.
    labels : torch.Tensor
        Target token ids of shape ``[batch, seq_len]``; positions equal to
        ``IGNORE_INDEX`` are excluded from the loss.
    token_mesh : DeviceMesh
        A device mesh that contains all ranks participating in this training step's
        loss computation. The function performs an ``all_reduce`` (mean) of each
        rank's ``num_tokens`` tensor across this mesh.
    ft_pg : dist.ProcessGroup | None
        Optional process group for fault-tolerant training; when provided, the
        average token count is additionally all-reduced (mean) across it.
Returns
-------
torch.Tensor
        A scalar loss tensor, ready for ``backward()`` and FSDP's gradient all-reduce (mean).

Notes
-----
* The function internally uses :func:`torch.nn.functional.cross_entropy`
with ``reduction="sum"`` so that each token contributes exactly once to
the numerator. The denominator is the **average** number of valid tokens
per rank, not the local count.
    * If a rank contains no valid tokens (i.e., all labels are ``IGNORE_INDEX``),
      its contribution to the sum is zero and its ``num_tokens`` becomes zero.
      In that case the mean across ranks is still well-defined as long as at
      least one rank has a non-zero token count.
"""
sum_loss = torch.nn.functional.cross_entropy(
pred.flatten(0, 1).float(),
labels.flatten(0, 1),
reduction="sum",
ignore_index=IGNORE_INDEX,
)
num_tokens = (labels != IGNORE_INDEX).sum()
avg_num_tokens_per_rank = funcol.all_reduce(
num_tokens, reduceOp=c10d.ReduceOp.AVG.name, group=token_mesh
)
if ft_pg is not None:
avg_num_tokens_per_rank = funcol.all_reduce(
avg_num_tokens_per_rank, reduceOp=c10d.ReduceOp.AVG.name, group=ft_pg
)
return sum_loss / avg_num_tokens_per_rank
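To make the docstring's argument concrete, here is a small single-process illustration (not part of the new file): two "ranks" are simulated with plain tensors, and the distributed mean all-reduce becomes an ordinary average.

import torch

# Per-token losses on two simulated ranks: rank 0 has 4 valid tokens, rank 1 has 1.
rank0 = torch.tensor([2.0, 2.0, 2.0, 2.0])
rank1 = torch.tensor([8.0])

# Local reduction="mean" + FSDP gradient averaging gives each rank's *mean* equal
# weight, so rank 1's single token dominates: (2.0 + 8.0) / 2 = 5.0
fsdp_style = (rank0.mean() + rank1.mean()) / 2

# True global per-token average over all 5 tokens: 16.0 / 5 = 3.2
global_mean = torch.cat([rank0, rank1]).mean()

# token_imbalance_ce_loss divides each rank's *sum* by the *average* token count
# per rank ((4 + 1) / 2 = 2.5); averaging the per-rank results recovers 3.2.
avg_tokens = (rank0.numel() + rank1.numel()) / 2
balanced = (rank0.sum() / avg_tokens + rank1.sum() / avg_tokens) / 2

print(fsdp_style.item(), global_mean.item(), balanced.item())  # 5.0 3.2 3.2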
Contributor:

Thanks for writing up the docstring, looks very good.

Contributor @tianyu-l (Oct 8, 2025):

This is a cute & "mostly correct" way to deal with the imbalanced token loss issue. However,

  1. it's not the most readable one
  2. moreover, I don't think it is correct if gradient accumulation is enabled, as each microbatch can have a different "avg_num_tokens"

I think the best way should be

  • don't let FSDP do implicit gradient division
  • always run cross entropy with reduction="sum"
  • let the data loader / trainer count the number of tokens involved in the loss computation, e.g. by explicitly doing num_tokens = (labels != IGNORE_INDEX).sum() on each rank. (I agree that without imbalance we don't need this count or the communication that follows.)

This way we also don't need this ad hoc call: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/parallelize.py#L363

We don't need to do this refactor for now, but it would be good if you could leave a TODO item here + file an issue.

cc @ezyang @fegin

Contributor Author:

I filed #1842

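A minimal sketch of the refactor proposed in the review comment above, under stated assumptions: the helpers below are illustrative rather than torchtitan code, FSDP's implicit gradient division is assumed to be disabled, and in the distributed case total_tokens would additionally be all-reduced (sum) across the dp/cp ranks and any fault-tolerant replicas.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def sum_ce_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Per-rank loss with reduction="sum": no implicit division anywhere.
    return F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )


def accumulate_step(model, microbatches):
    # Accumulate gradients of the *summed* loss over all microbatches while
    # counting valid tokens explicitly, then rescale once at the end. Every token
    # ends up weighted by 1 / total_tokens no matter how tokens are split across
    # microbatches (or, after a SUM all-reduce of total_tokens, across ranks).
    total_tokens = 0
    for inputs, labels in microbatches:
        pred = model(inputs)
        sum_ce_loss(pred, labels).backward()
        total_tokens += int((labels != IGNORE_INDEX).sum())
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(max(total_tokens, 1))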


def build_token_imbalance_ce_loss(
job_config: JobConfig, parallel_dims: ParallelDims, ft_manager: FTManager, **kwargs
):
del kwargs # delete any unused arguments
    # NOTE: the device mesh along which the input tokens (shape B x S x D) can be sliced:
    #   DP splits the batch dim B
    #   CP splits the sequence dim S
token_mesh = parallel_dims.world_mesh["dp_cp"]
ft_pg = ft_manager.loss_sync_pg
loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
if job_config.compile.enable and "loss" in job_config.compile.components:
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
return loss_fn
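The builder returns a functools.partial with the mesh and fault-tolerance group already bound, so the trainer calls it with the same (pred, labels) signature as the plain cross-entropy loss, and (as the code above shows) the partial can be passed straight to torch.compile. A tiny standalone sketch of that pattern, using a toy stand-in loss rather than the real one:

from functools import partial

import torch


def scaled_mse(pred: torch.Tensor, labels: torch.Tensor, scale: float) -> torch.Tensor:
    # Stand-in for token_imbalance_ce_loss: extra configuration is bound via partial.
    return scale * ((pred - labels) ** 2).mean()


loss_fn = partial(scaled_mse, scale=0.5)
loss_fn = torch.compile(loss_fn)  # same wrapping the builder above applies

pred = torch.randn(2, 4, requires_grad=True)
labels = torch.randn(2, 4)
loss_fn(pred, labels).backward()
print(pred.grad.shape)  # torch.Size([2, 4])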
4 changes: 3 additions & 1 deletion torchtitan/train.py
@@ -197,7 +197,9 @@ def __init__(self, job_config: JobConfig):
         init_device = device_type
         buffer_device = None
 
-        self.loss_fn = self.train_spec.build_loss_fn(job_config)
+        self.loss_fn = self.train_spec.build_loss_fn(
+            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
+        )
 
         # verify batch sizes
         global_batch_size = job_config.training.global_batch_size
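Taken together with the two builders in this PR, the call above implies a contract that every registered build_loss_fn must satisfy: accept job_config plus the parallel_dims and ft_manager keywords, and ignore whatever it does not need. A sketch of that contract as a typing.Protocol, purely illustrative rather than a class that exists in torchtitan:

from typing import Any, Callable, Protocol

import torch

LossFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


class LossBuilder(Protocol):
    def __call__(
        self,
        job_config: Any,           # JobConfig
        parallel_dims: Any = ...,  # ParallelDims; ignored by simple builders
        ft_manager: Any = ...,     # FTManager; ignored by simple builders
        **kwargs: Any,
    ) -> LossFunction:
        ...


# build_cross_entropy_loss conforms by absorbing the extras via **kwargs;
# build_token_imbalance_ce_loss consumes parallel_dims and ft_manager.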