[Pytorch] Implement fp32 accumulation for attention with context parallel in both forward and backward pass. #821
+35
−17
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
When using context parallelism, we've observed that adopting fp32 accumulation for attention operations in both the forward and backward passes significantly improves numerical accuracy. This approach aligns with practices seen in projects like Megatron-LM, which employs a similar strategy for bf16 precision view source. Our findings extend this benefit to fp16 precision as well.
Detailed Analysis:
By enforcing stricter tolerances using
torch.testing.assert_close
, withrtol = 1.3e-6
andatol = 1e-5
(the default for fp32), we can better assess improvements in our computations.Our test configuration is as follows:
Baseline Testing:
Utilizing the latest TransformerEngine commit, we identified significant mismatches and differences in both absolute and relative terms as shown in our test results.
Modifications and Results:
torch.testing.assert_close(out_, out, **tols)
):After modifying the forward pass to use an fp32 accumulation buffer, the attention out mismatches decreased from 62.0% to 50.7%, and the greatest absolute and relative differences reduced significantly.
Review commit
torch.testing.assert_close(torch.cat((dq_, dk_, dv_), 0), torch.cat((dq, dk, dv), 0), **tols)
)Adjustments to the backward pass, where dq & dkv accumulations now use an fp32 precision buffer, further reduced mismatches from 59.0% to 41.3%.
Review commit
Conclusion:
These changes enhance the stability of attention computations in context-parallel operations, with a marked improvement in numerical precision across tests.