Add efficient Cross-Entropy by cuda kernel to accelerate training speed and reduce cross-entropy memory usage during training. #995

Open
wants to merge 18 commits into base: main

Conversation


@cb521 cb521 commented Jul 8, 2024

Description

  Hi, we found that in the cross-entropy implementation of Megatron-LM, the input tensor must be converted to float for subsequent calculations, which results in redundant GPU memory usage and extra time. We used CUDA kernels to fuse the torch ops for the two forward steps and the one backward step of the cross-entropy computation: the float type conversion is performed inside the kernels, and several independent calculation steps are fused into a single one.

  To achieve this optimization, we developed three kernels in TransformerEngine and made corresponding changes in Megatron-LM. To evaluate the performance of the CUDA kernels, we also implemented an OpenAI Triton version of the kernel in Megatron-LM and compared it with the CUDA version.
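
As a rough illustration of what this fusion looks like, here is a minimal CUDA sketch (not the PR's actual code; kernel and variable names are simplified) of the forward "sum & exp" step: one thread block per row, bfloat16 logits converted to float in registers, and a block reduction of exp(x - row_max), so no float copy of the logits is ever materialized.

#include <cub/cub.cuh>
#include <cuda_bf16.h>

// Simplified sketch: one block per row; the real kernels additionally use
// vectorized (int4) loads and handle 16-byte-alignment edge cases.
template <int BlockSize>
__global__ void fused_sum_exp_sketch(const __nv_bfloat16* logits, const float* row_max,
                                     float* exp_sum, int n_dim) {
  using BlockReduceT = cub::BlockReduce<float, BlockSize>;
  __shared__ typename BlockReduceT::TempStorage temp_storage;

  const size_t row = blockIdx.x;
  const float cur_row_max = row_max[row];
  float thread_sum = 0.f;

  for (int i = threadIdx.x; i < n_dim; i += BlockSize) {
    // Convert to float inside the kernel instead of allocating a float copy.
    float val = static_cast<float>(logits[row * n_dim + i]);
    thread_sum += exp(val - cur_row_max);
  }

  float row_sum = BlockReduceT(temp_storage).Sum(thread_sum);
  if (threadIdx.x == 0) exp_sum[row] = row_sum;
}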

  Finally, after experimental verification, we found that using the CUDA kernels to optimize the current cross-entropy implementation effectively improves training speed and reduces GPU memory usage. The test results are shown below:

(Figure: cross-entropy-perf, performance and memory comparison)

As shown, there are currently two cross-entropy implementations in Megatron-LM: the original one ("original") and a second one called "fused". The "fused" implementation is faster than "original", but it consumes the most GPU memory. The two new implementations added in this PR, "Triton" and "CUDA Kernel", train faster and save more GPU memory than the existing Megatron-LM implementations.

We also compared the convergence of the four implementations and found that the loss curves were basically the same, indicating that there was no problem with the calculation accuracy.
(Figure: cross-entropy-loss, loss-curve comparison)

Notes:

  1. We trained on 8×H100 GPUs for 4 hours to run the performance and accuracy tests mentioned above.

  2. Wandb testing link: https://wandb.ai/megatron-core-moe-dev/binc-efficient-cross-entropy-no-moe?nw=nwuserbinc521

  3. Wandb report link: https://api.wandb.ai/links/megatron-core-moe-dev/m1qovycf

  4. Megatron-LM MR: https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/1846

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

  • Change A: For TE, we added three kernels: "CrossEntropyFwdSumExpKernel", "CrossEntropyFwdMeanLogKernel", and "CrossEntropyBwdKernel".
  • Change B: For Mcore, we added a new cross-entropy implementation and some conditional logic.

@cb521 cb521 changed the title Add efficient cross entropy by cuda kernel. Add efficient Cross-Entropy by cuda kernel to accelerate training speed and reduce cross-entropy memory usage during training. Jul 30, 2024
Comment on lines 13 to 18
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp");
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log");
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward");
Member

Suggested change
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp");
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log");
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward");
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
"Softmax Cross_entropy Forward Sum & Exp",
py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
"Softmax Cross_entropy Forward Mean & Log",
py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward",
py::call_guard<py::gil_scoped_release>());

Author

Hi @ksivaman, thank you very much for taking the time to review the code!

I have made the modifications.

@@ -0,0 +1,135 @@
import torch
Member

Suggested change
import torch
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

@@ -0,0 +1,135 @@
import torch
Member

This should be converted to use pytest similar to the remaining testing files. We would also need to add this to qa/L0_pytorch_unittest/test.sh to run this test in the CI.

Author

Pytest has been added now, thanks.

&vocab_parallel_logits_ptr[cur_vocab_parallel_logits_ptr_begin + i]);
dtype* bf_16_p = reinterpret_cast<dtype*>(&int4_arr);
#pragma unroll
for (int k = 0; k < 8; k++) {
Member

You assume here that dtype is 2B long - could we generalize this, e.g. with

Suggested change
for (int k = 0; k < 8; k++) {
for (int k = 0; k < sizeof(int4)/sizeof(dtype); k++) {

Author
@cb521 cb521 Aug 8, 2024

Hi @ptrendx, in the logic of the corresponding Megatron-LM code, the input can only be bfloat16, and we use int4 to vectorize the reads, so we can assume that the loop bound here is 8.

Member

Well, sure, but Transformer Engine is used not only with Megatron-LM, so we should not make unnecessary assumptions about the usage (especially since it does not actually cost us anything to be more general, as the logic change is very simple). Also, such hardcoding could lead to hard-to-track errors if something actually changes in the datatypes used.
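
For reference, a minimal sketch of the generalization being suggested (the helper name and arguments are illustrative, not taken from the PR):

// Number of dtype elements packed into one 16-byte (int4) vectorized load;
// this evaluates to 8 for bfloat16 but stays correct if the dtype changes.
template <typename dtype>
__device__ inline float sum_exp_packed(const int4& packed, float row_max) {
  constexpr int kElemsPerLoad = sizeof(int4) / sizeof(dtype);
  const dtype* elems = reinterpret_cast<const dtype*>(&packed);
  float acc = 0.f;
#pragma unroll
  for (int k = 0; k < kElemsPerLoad; k++) {
    acc += exp(static_cast<float>(elems[k]) - row_max);
  }
  return acc;
}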

Comment on lines +259 to +265
1024 | 1
7 | 1016 | 2
6 | 1016 | 3

For example:
1024 | 1 -> [0,1023] [1024]
7 | 1016 | 2 -> [0, 6], [7,1022], [1023,1024]
Member

Could we have some text here in addition to the numbers? I do not quite understand what they represent.

Author

Sure, let me describe this process. @ptrendx

Because the upcoming calculations use vectorized reads, loading 8 bfloat16 values at once, I need to ensure that the data being read is 16-byte aligned. In the cross-entropy computing scenario, ndim is usually a multiple of 8.

But to also handle cases where ndim is not a multiple of 8, I compute for each row the start and end offsets of the 16-byte-aligned region, and handle each row's misaligned parts with a simple scalar loop at the end of the kernel.

In our current threading model, one block is responsible for computing one row. When ndim is not an integer multiple of 8, each row is divided into up to three parts: the first and last parts are not 16-byte aligned, while the middle part is. This way, we can read the middle part with vectorized loads and process the beginning and end separately.

For example, if ndim is 1025, which is not a multiple of 8, the 1025 elements must be split into parts. For the first block (responsible for the first row) there are only two parts: the elements at positions [0, 1023] are 16-byte aligned, and the last element must be processed separately. For the second block (responsible for the second row) there are three parts: the elements at positions [7, 1022] are 16-byte aligned, while [0, 6] and [1023, 1024] must be processed separately.

The calculation of these indexes and values can be followed alongside the code. As shown in the following layout,
1024 | 1
7 | 1016 | 2
6 | 1016 | 3

if ndim = 1025 and we have three rows, then for the first row cur_vocab_parallel_logits_ptr_begin = 0, cur_vocab_parallel_logits_ptr_end = 1025, end_mol_num = 1, begin_mol_num = 1024, begin_offset = 0, end_offset = 1023.
We define the 16-byte-aligned region as [begin_offset, end_offset]; for the first row this region is [0, 1023], i.e. the first 1024 elements can be read with vectorized loads.

Similarly, for the second row, cur_vocab_parallel_logits_ptr_begin = 1025, cur_vocab_parallel_logits_ptr_end = 2050, end_mol_num = 2, begin_mol_num = 1023, begin_offset = 7, end_offset = 1022.
For the second row this region is [7, 1022], i.e. the 1016 elements in the middle. The first 7 and last 2 elements of this row must be processed separately because they are not 16-byte aligned.

The calculation logic for the third row is the same.
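
Below is a simplified host-side sketch of this offset computation (variable names differ from the kernel; it assumes n_dim >= 8 and that the logits base pointer itself is 16-byte aligned):

#include <cstddef>
#include <cstdio>

// begin_offset/end_offset are within-row indices bounding the 16-byte-aligned
// middle region that can be read with vectorized int4 loads (8 bf16 each).
void compute_aligned_region(size_t row_idx, size_t n_dim,
                            size_t* begin_offset, size_t* end_offset) {
  const size_t elems_per_load = 8;
  size_t row_begin = row_idx * n_dim;   // first global index of the row
  size_t row_end = row_begin + n_dim;   // one past the last index of the row

  // First global index in the row that is 16-byte aligned.
  size_t aligned_begin = (row_begin + elems_per_load - 1) / elems_per_load * elems_per_load;
  // Last global index covered by a full vectorized load that stays inside the row.
  size_t aligned_end = row_end / elems_per_load * elems_per_load - 1;

  *begin_offset = aligned_begin - row_begin;  // 0, 7, 6 for the three rows above
  *end_offset = aligned_end - row_begin;      // 1023, 1022, 1021
}

int main() {
  for (size_t row = 0; row < 3; row++) {
    size_t b, e;
    compute_aligned_region(row, 1025, &b, &e);
    printf("row %zu: aligned region [%zu, %zu]\n", row, b, e);
  }
  return 0;
}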

Member

Hmm, I did not yet read the actual note (will do in a second), but such explanations should be recorded somewhere possibly easier to find than a comment in a PR. At the very least we could move them to the PR description and in the code make a comment with a link to that PR?

__shared__ typename BlockReduceT::TempStorage temp_storage;

#pragma unroll
for (size_t i = begin_offset + tid * 8; i <= end_offset - 7; i += 8 * BlockSize) {
Member

Same as below, you assume 2B per data element.

Author

Done, same as the explanation above.

size_t cur_vocab_parallel_logits_ptr_end =
rowIdx * n_dim + n_dim; //cur_vocab_parallel_logits_ptr_end = 1025, 2050

size_t end_mol_num = cur_vocab_parallel_logits_ptr_end % 8; //end_mol_num = 1, end_mol_num = 2
Member

What is end_mol_num (and similarly begin_mol_num)? Please name the variable in a way that tells the purpose of it. The comments as they are right now do not help at all and I do not understand them (why is end_mol_num 1 and then 2?).

Author

Yes, sorry, the previous comments were not easy to understand. I explained the process in my response above. @ptrendx

#pragma unroll
for (int k = 0; k < 8; k++) {
dtype data_bf16 = bf_16_p[k];
float data_fp32 = float(data_bf16); //convert to float
Member

Suggested change
float data_fp32 = float(data_bf16); //convert to float
float data_fp32 = static_cast<float>(data_bf16);

Author

Done

for (int k = 0; k < 8; k++) {
dtype data_bf16 = bf_16_p[k];
float data_fp32 = float(data_bf16); //convert to float
row_item = exp(data_fp32 - cur_row_max);
Member

Suggested change
row_item = exp(data_fp32 - cur_row_max);
row_item = expf(data_fp32 - cur_row_max);

Author

Hi @ptrendx, we need to use the non-fast version exp instead of the fast version expf here, because in our testing expf may affect the accuracy of certain calculation results. In actual Megatron-LM training tasks, we found that the loss with expf was slightly higher than our baseline.

#pragma unroll
for (size_t k = cur_vocab_parallel_logits_ptr_begin;
k < (cur_vocab_parallel_logits_ptr_begin + begin_offset); k++) {
float val = float(vocab_parallel_logits_ptr[k]);
Member

Suggested change
float val = float(vocab_parallel_logits_ptr[k]);
float val = static_cast<float>(vocab_parallel_logits_ptr[k]);

Author

Yes, @ptrendx, thank you for the reminder. I have changed all float type conversions in the code to static_cast<float>.

for (size_t k = cur_vocab_parallel_logits_ptr_begin;
k < (cur_vocab_parallel_logits_ptr_begin + begin_offset); k++) {
float val = float(vocab_parallel_logits_ptr[k]);
row_item = exp(val - cur_row_max);
Member

Suggested change
row_item = exp(val - cur_row_max);
row_item = expf(val - cur_row_max);

Author

Done

Comment on lines +323 to +324
float val = float(vocab_parallel_logits_ptr[k]);
row_item = exp(val - cur_row_max);
Member

Same as above

Author

Done


float row_sum = BlockReduceT(temp_storage).Sum(cur_thread_exp_sum);

if (threadIdx.x == 0) {
Member

Wouldn't it be better to use multiple threads to do it (and then use shfl to do reduction inside a warp)?

Author

Yes, good idea, @ptrendx. Using multiple threads should accelerate the computation here and avoid warp divergence. I think this can be treated as an independent modification once you have no issues with the other parts.

But I should add that, as mentioned before, ndim is currently known to be a multiple of 8, so generally speaking this loop logic is not entered at all.
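
For reference, a minimal sketch of the warp-level reduction being suggested (the helper and the usage comment are illustrative, not the PR's code):

// Each of the 32 threads in a warp contributes a partial sum; after the loop
// lane 0 holds the total.
__device__ inline float warp_reduce_sum(float val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

// Instead of letting only threadIdx.x == 0 walk the misaligned tail, the first
// warp (threads with threadIdx.x < 32) could process it cooperatively, e.g.:
//   float partial = 0.f;
//   for (size_t k = tail_begin + threadIdx.x; k < tail_end; k += 32)
//     partial += exp(static_cast<float>(vocab_parallel_logits_ptr[k]) - cur_row_max);
//   partial = warp_reduce_sum(partial);
//   if (threadIdx.x == 0) row_sum += partial;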

Comment on lines +333 to +335
float row_item = exp(data_fp32 - cur_row_max);
row_item = row_item / cur_row_exp_sum; //compute softmax
row_item = log(row_item); //after softmax, compute log
Member

Suggested change
float row_item = exp(data_fp32 - cur_row_max);
row_item = row_item / cur_row_exp_sum; //compute softmax
row_item = log(row_item); //after softmax, compute log
float row_item = expf(data_fp32 - cur_row_max);
row_item = row_item / cur_row_exp_sum; //compute softmax
row_item = logf(row_item); //after softmax, compute log

Author

Done

}

template <typename dtype, int BlockSize>
void __global__ CrossEntropyFwdMeanLogKernel(float* mean_log_probs_ptr,
Member

I have basically the same requests as for the previous kernel.

Also, would it be possible to merge those 2 kernels via templating for example in order to not duplicate code? They seem mostly the same?

Author

Hi @ptrendx, from the perspective of the overall computational logic, these are all elementwise operations. In fact, apart from the specific elementwise op, these two kernels are indeed very similar.

But the usage scenario for these kernels is Megatron-LM, where they replace some torch ops, and the kernels are aimed at ordinary Python users. To keep the function naming clear, I suggest implementing the kernels separately. What do you think?
(screenshot: the separate Python-side wrapper functions)

Member

Right, on the Python side those would be exposed as separate functions like you have in this screenshot, no argument here. But internally the kernel code itself could be shared. Basically a structure like this:

template <typename Func>
__global__ void kernel(/* ... */) {
  // ...
  // use Func to do the per-element computation
  // ...
}

at::Tensor cross_entropy_fwd_sum_exp(/* ... */) {
  // ...
  kernel<SumExp><<<grid, block>>>(/* ... */);
  // ...
}

at::Tensor cross_entropy_fwd_mean_log(/* ... */) {
  // ...
  kernel<MeanLog><<<grid, block>>>(/* ... */);
  // ...
}
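
For illustration, one possible way to express the two per-element operations as functors plugged into such a shared kernel (names and signatures are hypothetical, not from the PR):

// Hypothetical per-element operations for the shared kernel template.
struct SumExp {
  // Contribution to the block-reduced sum of exp(x - row_max).
  __device__ static float apply(float x, float row_max, float /*row_exp_sum*/) {
    return exp(x - row_max);
  }
};

struct MeanLog {
  // Contribution to the block-reduced sum of log-softmax values
  // (the kernel would divide the reduced sum by n_dim afterwards).
  __device__ static float apply(float x, float row_max, float row_exp_sum) {
    return log(exp(x - row_max) / row_exp_sum);
  }
};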

Comment on lines +406 to +408
row = exp(row - logits_max);
row /= sum_exp_logits;
if (i == (size_t)masked_target_1d) { // i == masked_target_1d
Member

Suggested change
row = exp(row - logits_max);
row /= sum_exp_logits;
if (i == (size_t)masked_target_1d) { // i == masked_target_1d
row = expf(row - logits_max);
row /= sum_exp_logits;
if (i == static_cast<size_t>(masked_target_1d)) {

Author

Done

if (i == (size_t)masked_target_1d) { // i == masked_target_1d
row = row - softmax_update;
}
if (label_smoothing > 0) {
Member

Why not just use smoothing here - you waste a register by holding onto this value.

Author

I'm sorry, I didn't understand what you meant. Could you explain it more specifically? Thanks @ptrendx

Member

You pass both label_smoothing and smoothing to this function, whereas the smoothing value seems to be directly tied to the label_smoothing value via

smoothing = label_smoothing * vocab_size / (vocab_size - 1);

this formula. The condition label_smoothing > 0 is equivalent to smoothing > 0, right? So you could save a register by simply forgetting the label_smoothing value after the initial setting of smoothing (or even just pass the value of smoothing to the kernel in the first place, so this computation is not performed by every thread) and use smoothing instead:

if (smoothing > 0) {
  row -= smoothing * average_grad;
}

(Considering that smoothing can only be either 0 or > 0, you could even remove the conditional completely and always do row -= smoothing * average_grad;.)
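
A minimal sketch of the host-side change being suggested (the helper name is illustrative, not the PR's actual API):

// Computed once on the host so the division is not repeated by every thread,
// and the kernel only needs the single 'smoothing' value.
inline float compute_smoothing(float label_smoothing, int vocab_size) {
  return label_smoothing > 0.f ? label_smoothing * vocab_size / (vocab_size - 1) : 0.f;
}

// Launch (illustrative):
//   CrossEntropyBwdKernel<<<grid, block>>>(..., compute_smoothing(label_smoothing, vocab_size), ...);
// Inside the kernel the update can then be unconditional:
//   row -= smoothing * average_grad;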
