Skip to content

[release/2.2] [ROCm] Correct numerical issues in layer norm backwards kernel (#140259) #1767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2024

Conversation

jataylo
Copy link

@jataylo jataylo commented Dec 4, 2024

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel cuLoadWriteStridedInputs which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated mean and rstd from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER

…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
@jataylo jataylo changed the title [ROCm] Correct numerical issues in layer norm backwards kernel (#140259) [release/2.2] [ROCm] Correct numerical issues in layer norm backwards kernel (#140259) Dec 4, 2024
@rocm-repo-management-api
Copy link

Jenkins build for 5ef76a334d756b4e912e48c71d11ba4afb2e2887 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony pruthvistony merged commit ce8fba1 into release/2.2 Dec 6, 2024
3 of 5 checks passed
@pruthvistony pruthvistony deleted the rel22-picks-jack branch December 6, 2024 05:58
@jithunnair-amd
Copy link
Collaborator

cherry-pick --onto release/2.1

rocm-mici pushed a commit that referenced this pull request Dec 6, 2024
… kernel (pytorch#140259) (#1767)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
@rocm-mici
Copy link

Created branch release/2.1_cherry-pick_pr-1767 and #1777

@inemankov
Copy link

!cherry-pick --onto release/2.3 release/2.4 release/2.5

@rocm-mici
Copy link

Can't perform the cherry-pick keyword: unexpected error

@inemankov
Copy link

!cherry-pick --onto release/2.3 release/2.4 release/2.5

rocm-mici pushed a commit that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1767)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
@rocm-mici
Copy link

Nothing to cherry-pick onto the release/2.3 branch

Nothing to cherry-pick onto the release/2.4 branch

Created branch release/2.5_cherry-pick_pr-1767 and #1794

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants