From the discussion here: #1803 (comment)
Summary: currently, the loss is computed locally per rank, and then FSDP averages the gradients across all ranks. This works for the simple case where all ranks have the same number of valid tokens, but breaks down in more complex scenarios:
- VLM training with native resolution. [VLM] Add token-imbalance loss #1803
- Expert Parallelism https://github.com/pytorch/torchtitan/pull/1793/files#r2415207041
Outline of a solution:
- don't let FSDP do implicit gradient division
- always run cross entropy with reduction="sum"
- let the data loader / trainer count the number of tokens involved in the loss computation, e.g. by explicitly doing num_tokens = (labels != IGNORE_INDEX).sum() on each rank. (I agree that without imbalance we don't need this count or the communication that follows.) See the sketch after this list.
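A minimal sketch of how these three pieces could fit together, assuming a standard data-parallel setup. The function name `token_imbalance_loss` and the `logits` / `labels` tensors are placeholders for illustration, not torchtitan APIs, and the exact knob for disabling FSDP's implicit gradient division depends on the FSDP version in use.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value for tokens excluded from the loss


def token_imbalance_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute a per-token mean loss that is correct even when ranks
    have different numbers of valid tokens.

    logits: (batch, seq, vocab), labels: (batch, seq)
    """
    # 1. Sum (not average) the per-token losses on this rank.
    loss_sum = F.cross_entropy(
        logits.flatten(0, 1),   # (batch * seq, vocab)
        labels.flatten(),       # (batch * seq,)
        ignore_index=IGNORE_INDEX,
        reduction="sum",
    )

    # 2. Count the tokens that actually contribute to the loss on this rank.
    num_tokens = (labels != IGNORE_INDEX).sum()

    # 3. Sum the token counts across ranks so every rank normalizes by the
    #    same global denominator.
    if dist.is_initialized():
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)

    # 4. Divide by the global token count. With FSDP's implicit gradient
    #    averaging disabled, summing gradients across ranks then yields the
    #    correct global per-token mean.
    return loss_sum / num_tokens
```

The key point is that the division happens once, by the global valid-token count, instead of letting FSDP divide by the world size, which would silently assume every rank contributed the same number of tokens.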