Massively reduce LayerNorm/RMSNorm training memory usage by sharing saved tensor with other parts of the networks #430
Saving the output of the normalization instead of the input reduces the memory cost in modern networks, where the output is going to be saved anyway (e.g., by a following Linear layer) and the input is otherwise only needed here.
We do this by changing the norm backward kernels to load from the output instead of the input, and to compute the normalized tensor from the output instead of the input. To stabilize the gradients, we also clamp gamma by magnitude before dividing by it.
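As a rough illustration of the idea, here is a PyTorch sketch (not the actual CUDA kernel; the function name, argument layout, and clamp value are made up for this example):

```python
import torch

def layernorm_bwd_from_output(dy, y, gamma, beta, rstd, clamp=1e-4):
    # Illustrative Python reference for the kernel change (the real change is
    # in the CUDA backward kernels). Shapes: dy, y are [rows, hidden];
    # gamma, beta are [hidden]; rstd is the saved 1/sqrt(var + eps) per row.
    # Clamp gamma by magnitude so that dividing by it stays stable.
    gamma_safe = torch.where(gamma >= 0,
                             gamma.clamp(min=clamp),
                             gamma.clamp(max=-clamp))
    # Recover the normalized tensor from the saved output:
    #   y = x_hat * gamma + beta  =>  x_hat = (y - beta) / gamma
    x_hat = (y - beta) / gamma_safe

    # Standard LayerNorm backward, driven by x_hat instead of the input x.
    dgamma = (dy * x_hat).sum(dim=0)
    dbeta = dy.sum(dim=0)
    dxhat = dy * gamma
    dx = rstd.unsqueeze(-1) * (
        dxhat
        - dxhat.mean(dim=-1, keepdim=True)
        - x_hat * (dxhat * x_hat).mean(dim=-1, keepdim=True)
    )
    return dx, dgamma, dbeta
```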
However, for now this seems to come at the price of somewhat lower numerical precision. To be investigated further.
For now, we pass the operator tests if we increase the tolerances for layer normalization's dgamma and dx.
Effect
From the sequence parallel paper, layer norm costs 4sbh in memory per layer during training. However, if we save the output of the layer instead of the input, we can free the inputs, as they are not needed anywhere else. This amounts to ~1/6 of the total activation cost of a transformer model. Note that this may result in slightly higher numerical errors, because we use the output for the gradients, and by that point rounding errors may have propagated.
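For a sense of scale, a quick back-of-the-envelope calculation with hypothetical shapes (not numbers from this PR):

```python
# Hypothetical shapes, only to put a number on the 4*s*b*h figure
# (assuming 16-bit activations, so 4*s*b*h is a byte count).
s, b, h, layers = 2048, 8, 4096, 32      # sequence, micro-batch, hidden, layers
bytes_per_layer = 4 * s * b * h          # two norms, each saving s*b*h fp16 values
print(bytes_per_layer / 2**20, "MiB per layer")       # 256.0 MiB per layer
print(layers * bytes_per_layer / 2**30, "GiB total")  # 8.0 GiB across 32 layers
```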
TODO
For now we are passing the operator tests and the tests in
test_numerics.py
(with added LayerNorm and gradient tests); a sketch of the gradient comparison follows below. We will very soon add more features, including support for fp8 and integration with Python frameworks. However, we note that this might require rethinking the dynamics of LayerNorm+Linear/MLP in the presence of fp8. Maybe eventually we will just make the normalization layers output half/single-precision data and save those instead of the inputs.
Also
See NVIDIA/apex#1715