Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pytorch] Implement fp32 accumulation for attention with context parallel in both forward and backward pass. #821

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Yuxin-CV
Copy link

Summary:
When using context parallelism, we've observed that adopting fp32 accumulation for attention operations in both the forward and backward passes significantly improves numerical accuracy. This approach aligns with practices seen in projects like Megatron-LM, which employs a similar strategy for bf16 precision view source. Our findings extend this benefit to fp16 precision as well.

Detailed Analysis:
By enforcing stricter tolerances using torch.testing.assert_close, with rtol = 1.3e-6 and atol = 1e-5 (the default for fp32), we can better assess improvements in our computations.

Our test configuration is as follows:

Kernel Backend: FlashAttention
Model Configuration: ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias")
World Size: 8
Data Type: bf16
QKV Format: bshd
Docker Environment: NGC Docker 24.03-py3

Baseline Testing:
Utilizing the latest TransformerEngine commit, we identified significant mismatches and differences in both absolute and relative terms as shown in our test results.

Modifications and Results:

  • Attention Forward Pass (torch.testing.assert_close(out_, out, **tols)):
before
[rank7]: Mismatched elements: 487642 / 786432 (62.0%)
[rank7]: Greatest absolute difference: 0.001953125 at index (0, 0, 23, 1430) (up to 1e-05 allowed)
[rank7]: Greatest relative difference: 2784.0 at index (0, 1, 97, 826) (up to 1.3e-06 allowed)

after
[rank7]: Mismatched elements: 398818 / 786432 (50.7%)
[rank7]: Greatest absolute difference: 0.0009765625 at index (0, 0, 4, 1135) (up to 1e-05 allowed)
[rank7]: Greatest relative difference: 2896.0 at index (0, 1, 97, 826) (up to 1.3e-06 allowed)

After modifying the forward pass to use an fp32 accumulation buffer, the attention out mismatches decreased from 62.0% to 50.7%, and the greatest absolute and relative differences reduced significantly.
Review commit

  • Attention Backward Pass: (torch.testing.assert_close(torch.cat((dq_, dk_, dv_), 0), torch.cat((dq, dk, dv), 0), **tols))
before
[rank7]: Mismatched elements: 1391190 / 2359296 (59.0%)
[rank7]: Greatest absolute difference: 0.001953125 at index (0, 0, 5, 7, 34) (up to 1e-05 allowed)
[rank7]: Greatest relative difference: 6176.0 at index (0, 0, 219, 1, 98) (up to 1.3e-06 allowed)

after
[rank7]: Mismatched elements: 975076 / 2359296 (41.3%)
[rank7]: Greatest absolute difference: 0.001953125 at index (0, 0, 181, 3, 67) (up to 1e-05 allowed)
[rank7]: Greatest relative difference: 5600.0 at index (0, 0, 219, 1, 98) (up to 1.3e-06 allowed)

Adjustments to the backward pass, where dq & dkv accumulations now use an fp32 precision buffer, further reduced mismatches from 59.0% to 41.3%.
Review commit

Conclusion:
These changes enhance the stability of attention computations in context-parallel operations, with a marked improvement in numerical precision across tests.

@Yuxin-CV
Copy link
Author

Yuxin-CV commented Apr 28, 2024

This PR has successfully passed all scenarios covered by the L1_pytorch_context_parallel_test. Additionally, using the fp32 buffer resulted in a minimal increase in memory usage, less than 2%, which is negligible.

…ng fp32 precision when employing context parallelism

Signed-off-by: Yuxin-CV <[email protected]>
…precision buffer when employing context parallelism

Signed-off-by: Yuxin-CV <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant