
Commit da78ffb

jataylo authored and AMD AMD committed
[release/2.2] [ROCm] Correct numerical issues in layer norm backwards kernel (pytorch#140259) (#1767)
It was raised that the backwards layer norm on AMD was slightly less accurate than the equivalent NVIDIA implementation. On AMD we call into a helper kernel, `cuLoadWriteStridedInputs`, which processes strided input and accumulates the partial gradients into shared memory. In this kernel (pytorch#87635) we truncated `mean` and `rstd` from the `T_ACC` type to `T`, which causes numerical issues in the warp buffers created in this kernel. This PR uses the correct accumulator type for `mean` and `rstd`.

Note: only AMD calls into this call stack for backwards layer norm, so this was not an issue for NVIDIA.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh
(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
1 parent abbfe77 commit da78ffb

File tree

1 file changed: +2 -2 lines


aten/src/ATen/native/cuda/layer_norm_kernel.cu

+2 -2
@@ -840,8 +840,8 @@ void cuLoadWriteStridedInputs(
 {
   int i1 = i1_block+thr_load_row_off;
   if (i1 < i1_end) {
-    T curr_mean = mean[i1];
-    T curr_rstd = rstd[i1];
+    T_ACC curr_mean = mean[i1];
+    T_ACC curr_rstd = rstd[i1];
     for (int k = 0; k < blockDim.y; ++k) {
       int i2 = i2_off + k;
       int load_idx = i1*N+i2;
