[VLM] Add token-imbalance loss #1803
Conversation
Force-pushed from 0ec62b1 to 9015c60
ignore_index=IGNORE_INDEX,
)
num_tokens = (labels != IGNORE_INDEX).sum()
avg_num_tokens_per_rank = dist_mean(num_tokens, token_mesh)
For this dist_mean call, it seems it'll trigger a GPU/CPU sync:
return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
I think this will potentially bring unnecessary perf issues, as without it the CPU can stay way ahead of the GPU.
I'd recommend refactoring the .item() call out of _dist_reduce and putting it into the callsites. Alternatively, you can directly call funcol.all_reduce() here (rough sketch below).
cc @fegin: this won't work with FT, as the ft_pg is not visible to this loss function.
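A minimal sketch of the suggested refactor, assuming helper names that mirror torchtitan's distributed utils (`_dist_reduce`, `dist_mean`): the reduction stays a tensor, and callsites that actually need a Python scalar call `.item()` themselves, as late as possible.

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> torch.Tensor:
    # Return the reduced tensor directly; no .item(), so no GPU/CPU sync here.
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh)


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> torch.Tensor:
    return _dist_reduce(x, reduceOp="avg", mesh=mesh)


# Callsite example: sync only where a Python float is actually needed,
# e.g. when logging metrics.
# avg_loss = dist_mean(loss.detach(), mesh).item()
```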
I added an extra process group mimicking the distributed utils. Hope that makes sense.
)
num_tokens = (labels != IGNORE_INDEX).sum()
avg_num_tokens_per_rank = dist_mean(num_tokens, token_mesh)
return sum_loss / avg_num_tokens_per_rank
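For context, here is a self-contained sketch of what the snippet above computes (identifiers like `token_imbalance_ce_loss` and `token_mesh` are assumptions for illustration, not necessarily the exact names in the PR): sum the per-token cross entropy locally, then divide by the average token count across the data-parallel mesh so that FSDP's implicit gradient averaging recovers the global per-token mean.

```python
import torch
import torch.distributed._functional_collectives as funcol
import torch.nn.functional as F

IGNORE_INDEX = -100


def dist_mean(x: torch.Tensor, mesh) -> torch.Tensor:
    # All-reduce average across the (1-D) token mesh; result stays a tensor.
    return funcol.all_reduce(x, reduceOp="avg", group=mesh)


def token_imbalance_ce_loss(
    pred: torch.Tensor,    # [batch, seq, vocab] logits
    labels: torch.Tensor,  # [batch, seq] targets, IGNORE_INDEX for padding
    token_mesh,            # device mesh spanning the data-parallel ranks
) -> torch.Tensor:
    # Per-rank *sum* (not mean) of the per-token cross entropy.
    sum_loss = F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )
    # Tokens that actually participate in the loss on this rank
    # (cast to float so the averaged count divides cleanly).
    num_tokens = (labels != IGNORE_INDEX).sum().float()
    # Average token count across ranks: the one extra comm this PR adds.
    avg_num_tokens_per_rank = dist_mean(num_tokens, token_mesh)
    # Dividing by the average per-rank count (N_total / R) makes FSDP's
    # gradient averaging equal the global per-token mean.
    return sum_loss / avg_num_tokens_per_rank
```

The only extra communication relative to the standard mean-reduced cross entropy is the single all-reduce on the token count.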
This is a cute & "mostly correct" way to deal with the imbalanced-token loss issue. However,
- it's not the most readable one
- moreover, I don't think it is correct if gradient accumulation is enabled, as each microbatch can have a different `avg_num_tokens`

I think the best way would be to:
- not let FSDP do implicit gradient division
- always run cross entropy with `reduction="sum"`
- let the data loader / trainer count the number of tokens involved in the loss computation, e.g. by explicitly doing `num_tokens = (labels != IGNORE_INDEX).sum()` on each rank (I agree that without imbalance we wouldn't need this count or the accompanying communication)

This way we also wouldn't need this ad hoc call: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/parallelize.py#L363
We don't need to do this refactor for now, but it would be good if you could leave a TODO item here and file an issue; a rough sketch of the alternative is below.
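A rough sketch of that alternative, under the stated assumptions (FSDP configured to sum rather than average gradients, e.g. via a gradient divide factor of 1; helper names are illustrative): the loss is a plain sum, and the trainer normalizes once per optimizer step, which stays correct under gradient accumulation because the division happens after all microbatches have been accumulated.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

IGNORE_INDEX = -100


def sum_ce_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Per-rank sum of per-token losses; no division anywhere in the loss.
    return F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )


def normalize_grads_by_global_token_count(model, microbatch_labels, dp_group=None):
    # Called once per optimizer step, after all microbatch backwards.
    # Count tokens involved in the loss across this rank's microbatches.
    local_tokens = sum(
        (labels != IGNORE_INDEX).sum() for labels in microbatch_labels
    ).float()
    # Total tokens across all data-parallel ranks for this step.
    dist.all_reduce(local_tokens, group=dp_group)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(local_tokens)
```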
I filed #1842
from torchtitan.tools.logging import logger


IGNORE_INDEX = -100
Let's avoid defining IGNORE_INDEX in two places when it actually needs to be shared. The other appearance is in torchtitan/experiments/vlm/datasets/mm_collator_nld.py
I kept the definition in loss.py, as that makes the most sense, and import it into the dataloader files.
Actually, I removed the other PR #1802 from this PR, since they are fairly independent, to make this one cleaner to merge.
I will update the other PR to import IGNORE_INDEX depending on which one lands first.
each rank computes the loss over **only its local tokens** and returns an
*average* over those tokens:

Afterwards, when Fully‑Sharded Data Parallel (FSDP) averages the gradients
Thanks for writing up the docstring, looks very good.
Force-pushed from 8eb28c1 to 31dd8d9
LGTM
In VLM interleaved training with native resolution and aspect ratio, the number of tokens participating in the loss computation differs per rank. Naive FSDP gradient averaging across data ranks then causes tokens on ranks with fewer valid tokens to contribute more to the loss than tokens on other ranks.
This PR addresses this via loss balancing, which incurs an additional comm in the loss computation.
In practice, I haven't noticed any impact from this comm.
Quick sanity check
Let each rank $i$ have a sum loss over its $N_i$ tokens, $L_i = \sum_{j=1}^{N_i}\ell_{ij}$, with gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$.
If we multiply the *loss* on each rank by a constant factor $c$ (the same for all ranks), then after `backward()`:
$$ \tilde g_i = c \cdot g_i . $$
FSDP will *average* these gradients across ranks:
$$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$
We want this to equal the **global‑sample average**:
$$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$
Thus, for the FSDP gradient to be correct, we need
$$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}} . $$
So the *right* scaling factor is $R/N_{\text{total}}$, which means dividing the per-rank sum loss by $N_{\text{total}}/R$, the **average number of tokens per rank**.
Intuitively, this is the same as the default cross-entropy loss, except that instead of dividing the sum loss on a rank by the number of tokens **on that rank**, we divide by the **average number of tokens across all ranks**.
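As a concrete check with illustrative numbers: take $R = 2$ ranks with $N_1 = 10$ and $N_2 = 30$ tokens, so $N_{\text{total}} = 40$ and the average per-rank count is $20$. Dividing each rank's sum loss by $20$ and letting FSDP average the gradients gives $\tfrac{1}{2}(g_1/20 + g_2/20) = (g_1 + g_2)/40$, exactly the per-token mean over all $40$ tokens. With the naive per-rank mean ($L_1/10$ and $L_2/30$), FSDP averaging yields $g_1/20 + g_2/60$, so each token on rank 1 would be weighted three times as heavily as a token on rank 2.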
P/s: sorry this PR is based on #1802 but I couldn't choose that as the base branch. Maybe it will be easier to review once that PR is merged.