
Massively reduce LayerNorm/RMSNorm training memory usage by sharing saved tensor with other parts of the networks #430

Draft

RuiWang1998 wants to merge 7 commits into main from rui/dev-mem-eff-ln-operator
Conversation

RuiWang1998

Saving the output of the normalization instead of the input reduces the memory cost in modern networks, where the output is going to be saved anyway by the next layer (e.g., a Linear layer), while the input is only needed by the normalization itself.

We do this by changing the norm backward kernels to load from the output instead of the input, and to compute the normalized tensor from the output rather than the input. To stabilize gradients, we also clamp the gamma value by magnitude before dividing by it.
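
To make the idea concrete, here is a minimal PyTorch sketch of the approach (my own illustration, not the PR's CUDA kernels): the autograd function saves the output `y` and recovers the normalized tensor in backward as `(y - beta) / gamma`, with `gamma` clamped by magnitude; the clamp threshold `GAMMA_MIN` is an assumed value for illustration.

```python
import torch


class MemEffLayerNorm(torch.autograd.Function):
    """Sketch of a LayerNorm that saves its output for the backward pass.

    The normalized tensor is recovered in backward as (y - beta) / gamma,
    with gamma clamped by magnitude so the division stays well conditioned.
    Since y is typically kept alive anyway by the following Linear layer,
    the input x no longer needs to be saved.
    """

    GAMMA_MIN = 1e-3  # assumed clamp threshold, for illustration only

    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        mean = x.mean(dim=-1, keepdim=True)
        rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
        y = (x - mean) * rstd * gamma + beta
        # Save the output plus the small per-row rstd; x, mean and x_hat
        # are intentionally not saved.
        ctx.save_for_backward(y, gamma, beta, rstd)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, gamma, beta, rstd = ctx.saved_tensors
        # Clamp gamma by magnitude (sign-preserving) before dividing by it.
        sign = torch.where(gamma >= 0, torch.ones_like(gamma), -torch.ones_like(gamma))
        g = sign * gamma.abs().clamp(min=MemEffLayerNorm.GAMMA_MIN)
        x_hat = (y - beta) / g  # recomputed from the saved output
        dx_hat = dy * gamma
        dx = rstd * (dx_hat
                     - dx_hat.mean(dim=-1, keepdim=True)
                     - x_hat * (dx_hat * x_hat).mean(dim=-1, keepdim=True))
        h = y.shape[-1]
        dgamma = (dy * x_hat).reshape(-1, h).sum(dim=0)
        dbeta = dy.reshape(-1, h).sum(dim=0)
        return dx, dgamma, dbeta, None
```

Calling `MemEffLayerNorm.apply(x, gamma, beta)` and comparing gradients against `torch.nn.functional.layer_norm` should agree up to the looser tolerances discussed below, except where `gamma` is small enough for the clamp to kick in.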

However, for now this comes at the price of somewhat lower numerical precision. To be investigated further.

For now, the operator tests pass if we increase the tolerances for layer normalization's dgamma and dx.

Effect

According to the sequence parallelism paper, layer norm costs 4sbh bytes of memory per layer during training. However, if we save the output of the layer instead of the input, the inputs can be freed since they are not needed anywhere else. This amounts to roughly 1/6 of the total activation cost of a transformer model.
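
For a concrete sense of scale, here is a back-of-the-envelope version of the 4sbh figure (the dimensions below are illustrative, not taken from this PR):

```python
# Illustrative arithmetic only (example dimensions, not from this PR).
# The sequence-parallelism paper counts 4*s*b*h bytes per transformer layer
# for the two LayerNorm inputs saved in 16-bit precision (2 bytes each).
s, b, h = 2048, 8, 4096  # sequence length, micro-batch size, hidden size
layernorm_activation_bytes = 4 * s * b * h
print(layernorm_activation_bytes / 2**20, "MiB per layer")  # -> 256.0 MiB per layer
```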

Note that this may result in slightly higher numerical error, because gradients are computed from the output, by which point rounding errors have already propagated into it.

TODO

For now we are passing the operator tests and the tests in test_numerics.py (with added LayerNorm and gradient tests). We will soon add more features, including support for fp8 and integration with Python frameworks.
However, we note that this might require rethinking the dynamics of LayerNorm+Linear/MLP in the presence of fp8. Eventually we may simply have the normalization layers output half-/single-precision data and save that instead of the inputs.

Also

See NVIDIA/apex#1715

Saving the output of the normalization instead of the input reduces the
memory cost in modern networks, where the output is going to be saved
anyway (e.g., by a following Linear layer) and the input is otherwise
only needed here.

We do this by changing the norm backward kernels to load from the output
instead of the input, and to compute the normalized tensor from the
output instead of the input. To stabilize gradients, we also clamp the
gamma value by magnitude before dividing by it.

For now, the LayerNorm and RMSNorm tests pass with somewhat higher
numerical differences, which is expected since we use the output to
compute the gradient and rounding errors propagate that way.

Signed-off-by: Rui Wang <[email protected]>
@ptrendx self-requested a review on September 11, 2023 16:10
Read/Write FP8 output in the backward/forward pass of the layer norm for
even faster and more efficient forward/backward runs

Signed-off-by: Rui Wang <[email protected]>
I also added an FP8 bwd implementation for both norms

Signed-off-by: Rui Wang <[email protected]>
Also fixed a bug in calling the bwd function
Should have been removed in the last commit

Signed-off-by: Rui Wang <[email protected]>
@RuiWang1998 force-pushed the rui/dev-mem-eff-ln-operator branch from 27688b4 to c3a2042 on September 13, 2023 03:20
ptrendx (Member) commented Sep 18, 2023

@timmoon10 Could you help in reviewing this PR?

Fix linting for normalization.cu and other Python files
Revert import statements to their original version
Also a tiny fix to LayerNorm inference

Signed-off-by: Rui Wang <[email protected]>
Signed-off-by: Rui Wang <[email protected]>
LayerNorm inference used a different API in torch

Signed-off-by: Rui Wang <[email protected]>