
Conversation


@lkhphuc lkhphuc commented Oct 6, 2025

In VLM interleaved training, with native resolution and aspect ratio, the number of tokens participating in loss computation differs per rank. Naive FSDP gradient averaging across data ranks can cause tokens on ranks with fewer valid tokens to contribute more to the loss than tokens on other ranks.
This PR addresses this via loss balancing, which incurs an additional comm in the loss computation.
In practice, I haven't noticed any impact from this comm.

Quick sanity check

Let the sum loss over all tokens on rank $i$, which has $N_i$ tokens, be $L_i = \sum_{j=1}^{N_i}\ell_{ij}$, with gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$.

If we multiply the loss on each rank by a constant factor c (the same for all ranks), then after backward():

$$ \tilde g_i = c \cdot g_i . $$

FSDP will average these gradients across ranks:

$$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$

We want this to equal the global‑sample average:

$$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$

Thus, for the FSDP gradient to be correct, we need

$$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}}. $$

So the right scaling factor is $R/N_{\text{total}}$, which means dividing the per-rank sum loss by $N_{\text{total}}/R$, the average number of tokens per rank.
Intuitively, this is the same as the default cross-entropy loss, but instead of dividing the sum loss on a rank by the number of tokens on that rank, we now divide by the average number of tokens across all ranks.
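For concreteness, here is a minimal sketch of such a balanced loss, following the shape of this PR's diff. The `dist_mean` over a device mesh is replaced here by a plain `all_reduce` over the default process group, and the function name is just for illustration:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

IGNORE_INDEX = -100  # tokens with this label are excluded from the loss


def balanced_cross_entropy(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-rank sum loss divided by the average token count across ranks."""
    # Sum (not mean) of the cross entropy over this rank's valid tokens.
    sum_loss = F.cross_entropy(
        pred.flatten(0, 1).float(),   # (B*T, V)
        labels.flatten(0, 1),         # (B*T,)
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )
    # Number of tokens participating in the loss on this rank (N_i).
    num_tokens = (labels != IGNORE_INDEX).sum().float()
    # One extra all-reduce: N_total / R, the average token count per rank.
    total_tokens = num_tokens.clone()
    dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)
    avg_num_tokens_per_rank = total_tokens / dist.get_world_size()
    # Dividing by N_total / R makes FSDP's subsequent 1/R gradient averaging
    # produce the true global per-token average gradient.
    return sum_loss / avg_num_tokens_per_rank
```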

P.S.: Sorry, this PR is based on #1802 but I couldn't choose that as the base branch. It may be easier to review once that PR is merged.

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 6, 2025
@lkhphuc force-pushed the vlm-loss branch 2 times, most recently from 0ec62b1 to 9015c60 on October 8, 2025
ignore_index=IGNORE_INDEX,
)
num_tokens = (labels != IGNORE_INDEX).sum()
avg_num_tokens_per_rank = dist_mean(num_tokens, token_mesh)
Contributor:

For this dist_mean call, it seems it'll trigger a GPU/CPU sync

return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()

I think this will potentially bring unnecessary perf issues, as without it the CPU can stay way ahead of the GPU.

I'd recommend refactoring the .item() call out of _dist_reduce and putting it into the callsites. Alternatively, you can directly call funcol.all_reduce() here.

cc @fegin: this won't work with FT, as the ft_pg is not visible to this loss function.
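For illustration, a minimal sketch of that suggestion (hypothetical helper names, and assuming the string "avg" is an accepted reduceOp here; the point is just to return a tensor and let callers decide if and when to sync):

```python
import torch
import torch.distributed._functional_collectives as funcol


def _dist_reduce_tensor(x: torch.Tensor, reduceOp: str, mesh) -> torch.Tensor:
    # Same collective as the quoted _dist_reduce, but without .item(), so the
    # CPU is not forced to wait for the GPU here.
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh)


def dist_mean(x: torch.Tensor, mesh) -> torch.Tensor:
    # Loss code can keep using the returned tensor directly; metrics code that
    # truly needs a Python float can call .item() at its own call site.
    return _dist_reduce_tensor(x, reduceOp="avg", mesh=mesh)
```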

Contributor Author:

I added an extra process group mimicking the distributed utils. Hope that makes sense.

)
num_tokens = (labels != IGNORE_INDEX).sum()
avg_num_tokens_per_rank = dist_mean(num_tokens, token_mesh)
return sum_loss / avg_num_tokens_per_rank
@tianyu-l (Contributor) commented Oct 8, 2025:

This is a cute & "mostly correct" way to deal with the imbalanced token loss issue. However,

  1. it's not the most readable one
  2. moreover, I don't think it is correct if gradient accumulation is enabled, as each microbatch can have a different "avg_num_tokens"

I think the best way should be

  • don't let FSDP do implicit gradient division
  • always run cross entropy with reduction="sum"
  • let data loader / trainer count the number of tokens involved in loss computation, e.g. by explicitly doing num_tokens = (labels != IGNORE_INDEX).sum() on each rank. (I agree that without imbalance we don't need to do this and the communication that follows.)

This way we also don't need this ad hoc call https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/parallelize.py#L363

We don't need to do this refactor for now, but it would be good if you could leave a TODO item here + file an issue.

cc @ezyang @fegin
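To make the suggestion concrete, a rough sketch of that direction (hypothetical function names; this assumes FSDP's implicit 1/R gradient division is disabled or compensated elsewhere, so gradients are effectively summed across ranks before the division by the global token count):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

IGNORE_INDEX = -100


def sum_cross_entropy(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Pure per-rank sum; no division happens inside the loss function.
    return F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )


def global_token_count(labels_per_microbatch: list[torch.Tensor]) -> torch.Tensor:
    # Trainer side, once per optimizer step: accumulate token counts over all
    # gradient-accumulation microbatches first, then all-reduce, so imbalance
    # between microbatches is handled correctly.
    num_tokens = sum((lbl != IGNORE_INDEX).sum() for lbl in labels_per_microbatch)
    num_tokens = num_tokens.float()
    dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)  # N_total across ranks
    # The accumulated sum loss (or its gradients) then gets divided by N_total.
    return num_tokens
```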

Contributor Author:

I filed #1842

from torchtitan.tools.logging import logger


IGNORE_INDEX = -100
Contributor:

Let's avoid defining IGNORE_INDEX in two places when they actually need to be shared.
The other appearance is in torchtitan/experiments/vlm/datasets/mm_collator_nld.py

Contributor Author:

I kept the definition in loss.py, as that makes the most sense, and import it into the dataloader files.

@lkhphuc (Contributor Author) commented Oct 9, 2025:

Actually, I removed the other PR #1802 from this PR, since they are fairly independent, to make it cleaner to merge.
I will fix the other PR to import IGNORE_INDEX, depending on which one lands first.

each rank computes the loss over **only its local tokens** and returns an
*average* over those tokens:

Afterwards, when Fully‑Sharded Data Parallel (FSDP) averages the gradients
Contributor:

Thanks for writing up the docstring, looks very good.

@tianyu-l (Contributor) left a comment:

LGTM

@tianyu-l tianyu-l merged commit 98d904f into pytorch:main Oct 9, 2025
8 checks passed
@lkhphuc lkhphuc deleted the vlm-loss branch October 10, 2025 02:52
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 15, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 16, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 16, 2025