Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add efficient Cross-Entropy by cuda kernel to accelerate training speed and reduce cross-entropy memory usage during training. #995

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions tests/pytorch/test_efficient_memory_cross_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import torch
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be converted to use pytest similar to the remaining testing files. We would also need to add this to qa/L0_pytorch_unittest/test.sh to run this test in the CI.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pytest has been added now, THX.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex


def test_cross_entropy_fwd_sum_exp_torch(vocab_parallel_logits, max_logit):

# step2: substraction max_logit
vocab_parallel_logits = vocab_parallel_logits - max_logit.unsqueeze(dim=-1)
# step 3: exp
exp_logits = torch.exp(vocab_parallel_logits)
# step 4: sum
ret = torch.sum(exp_logits, dim=-1)
return ret


def check_cross_entropy_fwd_sum_exp_cuda():
# cuda kernel logic
s, b, v = 3, 1, 1024
vocab_parallel_logits = torch.randn(s, b, v).to(torch.bfloat16).cuda() # bf16
arr = [0.4, 0.5, 0.6]
for i in range(3):
vocab_parallel_logits[i] = arr[i]
# vocab_parallel_logits.fill_(0.55)
vocab_parallel_logits.to(torch.bfloat16)

logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
logits_max = logits_max.to(torch.float32)
logits_max.fill_(0.45).to(torch.float32)
n_dim = vocab_parallel_logits.size(-1)

sum_exp_logits = tex.cross_entropy_fwd_sum_exp_cuda(vocab_parallel_logits, logits_max)
print(sum_exp_logits.shape)

print()

sum_exp_logits_torch = test_cross_entropy_fwd_sum_exp_torch(vocab_parallel_logits, logits_max)
print(sum_exp_logits_torch.shape)
# print(torch.allclose(sum_exp_logits, sum_exp_logits_torch))


def test_cross_entropy_fwd_mean_log_torch(vocab_parallel_logits, max_logit, sum_exp_logits):

vocab_parallel_logits = vocab_parallel_logits - max_logit.unsqueeze(dim=-1)
exp_logits = torch.exp(vocab_parallel_logits)
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

log_probs = torch.log(exp_logits)
mean_log_probs = log_probs.mean(dim=-1)

return mean_log_probs


def check_cross_entropy_fwd_mean_log_cuda():
# cuda kernel logic
s, b, v = 1024, 4, 256000
vocab_parallel_logits = torch.randn(s, b, v).to(torch.bfloat16).cuda() # bf16
# arr = [0.023, 0.643, 0.195]
# for i in range(3):
# vocab_parallel_logits[i] = arr[i]
vocab_parallel_logits.fill_(0.55)
vocab_parallel_logits.to(torch.bfloat16)

logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
logits_max = logits_max.to(torch.float32)
logits_max.fill_(0.73).to(torch.float32)

sum_exp_logits = torch.empty_like(logits_max)
sum_exp_logits.fill_(3).to(torch.float32)

n_dim = vocab_parallel_logits.size(-1)

mean_log_probs = tex.cross_entropy_fwd_mean_log_cuda(
vocab_parallel_logits, logits_max, sum_exp_logits
)
print(mean_log_probs)

print("-------------------------")

mean_log_probs_torch = test_cross_entropy_fwd_mean_log_torch(
vocab_parallel_logits, logits_max, sum_exp_logits
)
print(mean_log_probs_torch)
# print(torch.allclose(mean_log_probs, mean_log_probs_torch))


def check_cross_entropy_bwd_cuda():
# cuda kernel logic
s, b, v = 3, 1, 1025
input_ptr = torch.randn(s, b, v).to(torch.bfloat16).cuda() # bf16
arr = [0.090, 0.777, 0.595]
for i in range(3):
input_ptr[i] = arr[i]

# input_ptr.fill_(0.55)
input_ptr.to(torch.bfloat16)

logits_max = torch.max(input_ptr, dim=-1)[0]
logits_max = logits_max.to(torch.float32)
logits_max.fill_(0.7).to(torch.float32)

sum_exp_logits = torch.empty_like(logits_max)
sum_exp_logits.fill_(3).to(torch.float32)

label_smoothing = 0.12
vocab_size = 666
grad_output_ptr = torch.empty_like(logits_max)
grad_output_ptr.fill_(0.88).to(torch.float32)

target_mask_ptr = torch.empty_like(logits_max, dtype=torch.bool)
target_mask_ptr.fill_(0).to(torch.bool)

masked_target_1d_ptr = torch.empty_like(logits_max, dtype=torch.int64)
masked_target_1d_ptr.fill_(1).to(torch.int64)
masked_target_1d_ptr = masked_target_1d_ptr.view(-1)

n_dim = input_ptr.size(-1)

grad_input_ptr = tex.cross_entropy_bwd_cuda(
grad_output_ptr,
input_ptr,
target_mask_ptr,
masked_target_1d_ptr,
logits_max,
sum_exp_logits,
label_smoothing,
vocab_size,
)
print(grad_input_ptr)


if __name__ == "__main__":
# check_cross_entropy_fwd_sum_exp_cuda()
# check_cross_entropy_fwd_mean_log_cuda()
check_cross_entropy_bwd_cuda()
18 changes: 18 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,21 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
bool wd_after_momentum, float scale);

#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

/***************************************************************************************************
* Support memory efficient cross entropy for Megatron-LM
**************************************************************************************************/

at::Tensor cross_entropy_forward_sum_exp(const at::Tensor &vocab_parallel_logits_ptr,
const at::Tensor &logits_max_ptr);

at::Tensor cross_entropy_fwd_mean_log(const at::Tensor &vocab_parallel_logits_ptr,
const at::Tensor &logits_max_ptr,
const at::Tensor &sum_exp_logits_ptr);

at::Tensor cross_entropy_bwd(const at::Tensor &grad_output_ptr,
const at::Tensor &input_ptr, //vocab_parallel_logits_ptr
const at::Tensor &target_mask_ptr,
const at::Tensor &masked_target_1d_ptr,
const at::Tensor &logits_max_ptr, const at::Tensor &sum_exp_logits_ptr,
float label_smoothing, size_t vocab_size);
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
#include "../extensions.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp");
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log");
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp");
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log");
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward");
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp",
py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log",
py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward",
py::call_guard<py::gil_scoped_release>());

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi ksivaman, Thank you very much for taking the time to review the code for me!

I have already made the modifications now.

// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
py::call_guard<py::gil_scoped_release>());
Expand Down
Loading