Add an efficient cross-entropy CUDA kernel to accelerate training and reduce cross-entropy memory usage during training #995
Conversation
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
      "Softmax Cross_entropy Forward Sum & Exp");
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
      "Softmax Cross_entropy Forward Mean & Log");
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward");
Suggested change:
// Efficeint memory softmax cross entropy
m.def("cross_entropy_fwd_sum_exp_cuda", &cross_entropy_forward_sum_exp,
      "Softmax Cross_entropy Forward Sum & Exp",
      py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_fwd_mean_log_cuda", &cross_entropy_fwd_mean_log,
      "Softmax Cross_entropy Forward Mean & Log",
      py::call_guard<py::gil_scoped_release>());
m.def("cross_entropy_bwd_cuda", &cross_entropy_bwd, "Softmax Cross_entropy Backward",
      py::call_guard<py::gil_scoped_release>());
Hi @ksivaman, thank you very much for taking the time to review the code! I have made the modifications.
@@ -0,0 +1,135 @@
import torch
Suggested change:
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
@@ -0,0 +1,135 @@
import torch
This should be converted to use pytest, similar to the remaining testing files. We would also need to add it to qa/L0_pytorch_unittest/test.sh to run this test in the CI.
Pytest has been added now, thanks.
    &vocab_parallel_logits_ptr[cur_vocab_parallel_logits_ptr_begin + i]);
dtype* bf_16_p = reinterpret_cast<dtype*>(&int4_arr);
#pragma unroll
for (int k = 0; k < 8; k++) {
You assume here that dtype is 2B long - could we generalize this, e.g. with
Suggested change:
for (int k = 0; k < sizeof(int4)/sizeof(dtype); k++) {
Hi @ptrendx, in the corresponding Megatron-LM logic the input can only be bfloat16, and we use int4 to vectorize the reads, so we can assume the loop bound here is 8. @ptrendx
Well, sure, but Transformer Engine is used not only with Megatron-LM, so we should not make unnecessary assumptions about the usage (especially since being more general does not actually cost us anything, as the logic change is very simple). Also, such hardcoding could later lead to hard-to-track errors if something actually changes in the datatypes used.
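As a rough illustration of that generalization (a sketch only, reusing the identifiers from the diff above, not the PR's final code), the element count per 16-byte vector can be derived from the type instead of being hardcoded to 8:

// Number of dtype elements in one 16-byte (int4) vector load:
// 8 for bfloat16/half, 4 for float.
constexpr int kElemsPerVec = sizeof(int4) / sizeof(dtype);
int4 int4_arr = *reinterpret_cast<const int4*>(
    &vocab_parallel_logits_ptr[cur_vocab_parallel_logits_ptr_begin + i]);
const dtype* elems = reinterpret_cast<const dtype*>(&int4_arr);
#pragma unroll
for (int k = 0; k < kElemsPerVec; k++) {
  float val = static_cast<float>(elems[k]);  // per-element work goes here
}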
1024 | 1
7 | 1016 | 2
6 | 1016 | 3

For example:
1024 | 1 -> [0,1023] [1024]
7 | 1016 | 2 -> [0, 6], [7,1022], [1023,1024]
Could we have some text here in addition to the numbers? I do not quite understand what they represent.
Sure, let me describe this process. @ptrendx
Because the upcoming calculations use vectorized reads of 8 bfloat16 values at once, I need to ensure that the data being read is 16-byte aligned. In the cross-entropy computing scenario, ndim is usually a multiple of 8, but to stay compatible with future cases where ndim is not a multiple of 8, I use simple arithmetic to compute, for each row, the start and end offsets of its 16-byte-aligned region. The misaligned elements of each row are then handled separately with a plain loop at the end of the kernel.
In our threading model, one block is responsible for computing one row. When ndim is not an integer multiple of 8, each row is divided into up to three parts: the first and last parts are not 16-byte aligned, and the middle part is. The middle part is read with vectorized loads, and the head and tail are processed element by element.
For example, if ndim is 1025 (not a multiple of 8), the 1025 elements of each row are split as follows. For the first block, which computes the first row, there are only two parts: the elements at positions [0, 1023] are 16-byte aligned, and the last one has to be processed separately. For the second block, which computes the second row, there are three parts: the elements at positions [7, 1022] are 16-byte aligned, while [0, 6] and [1023, 1024] have to be processed separately.
The index and value calculations can be followed alongside the code.
As shown in the following topology figure (the sizes of each row's parts; the first row has no unaligned head):
1024 | 1
7 | 1016 | 2
6 | 1016 | 3
If ndim = 1025 and we have three rows, then for the first row cur_vocab_parallel_logits_ptr_begin = 0, cur_vocab_parallel_logits_ptr_end = 1025, end_mol_num = 1, begin_mol_num = 1024, begin_offset = 0, end_offset = 1023.
We define the 16-byte-aligned region as [begin_offset, end_offset]; for the first row this region is [0, 1023], i.e. the first 1024 elements can be read with vectorized loads.
Similarly, for the second row, cur_vocab_parallel_logits_ptr_begin = 1025, cur_vocab_parallel_logits_ptr_end = 2050, end_mol_num = 2, begin_mol_num = 1023, begin_offset = 7, end_offset = 1022.
For the second row the 16-byte-aligned region [begin_offset, end_offset] is [7, 1022], i.e. the 1016 elements in the middle. The first 7 and last 2 elements of that row need to be processed separately because they are not 16-byte aligned.
The calculation for the third row follows the same logic.
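For reference, a minimal sketch of that index arithmetic, reconstructed from the example values above (variable names follow the kernel loosely; the exact code in the PR may differ):

// One block handles one row of length n_dim; the row occupies global element
// indices [rowIdx * n_dim, rowIdx * n_dim + n_dim). With 2-byte bf16 elements,
// a 16-byte-aligned segment must start at a global index that is a multiple of 8.
size_t row_begin = rowIdx * n_dim;               // 0, 1025, 2050 for n_dim = 1025
size_t row_end = row_begin + n_dim;              // 1025, 2050, 3075
size_t end_mol_num = row_end % 8;                // unaligned tail length: 1, 2, 3
size_t begin_offset = (8 - row_begin % 8) % 8;   // unaligned head length: 0, 7, 6
size_t end_offset = n_dim - end_mol_num - 1;     // last aligned in-row index: 1023, 1022, 1021
// [begin_offset, end_offset] is read with int4 vector loads; the head
// [0, begin_offset) and the tail (end_offset, n_dim) are handled element by element.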
Hmm, I did not yet read the actual note (will do in a second), but such explanations should be recorded somewhere possibly easier to find than a comment in a PR. At the very least we could move them to the PR description and in the code make a comment with a link to that PR?
__shared__ typename BlockReduceT::TempStorage temp_storage;

#pragma unroll
for (size_t i = begin_offset + tid * 8; i <= end_offset - 7; i += 8 * BlockSize) {
Same as below, you assume 2B per data element.
Done, same as the explanation above.
size_t cur_vocab_parallel_logits_ptr_end =
    rowIdx * n_dim + n_dim;  // cur_vocab_parallel_logits_ptr_end = 1025, 2050

size_t end_mol_num = cur_vocab_parallel_logits_ptr_end % 8;  // end_mol_num = 1, end_mol_num = 2
What are end_mol_num and begin_mol_num? Please name the variables in a way that conveys their purpose. The comments as they are right now do not help at all and I do not understand them (why is end_mol_num 1 and then 2?).
Yes, sorry, the previous comments were not easy to understand. I explained the process in my response above. @ptrendx
#pragma unroll
for (int k = 0; k < 8; k++) {
  dtype data_bf16 = bf_16_p[k];
  float data_fp32 = float(data_bf16);  // convert to float
Suggested change:
float data_fp32 = static_cast<float>(data_bf16);
Done
for (int k = 0; k < 8; k++) {
  dtype data_bf16 = bf_16_p[k];
  float data_fp32 = float(data_bf16);  // convert to float
  row_item = exp(data_fp32 - cur_row_max);
Suggested change:
row_item = expf(data_fp32 - cur_row_max);
Hi @ptrendx, we need to use the non-fast version exp instead of the fast version expf here, because in our testing expf can affect the accuracy of certain results. In actual Megatron-LM training tasks we found that the loss with expf was slightly higher than our baseline.
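For context, a small sketch of the two variants being compared (the helper names are hypothetical, not part of the PR):

// expf() is the single-precision function and, under --use_fast_math, may be
// lowered to the approximate __expf(); routing through double keeps full
// precision at some extra cost, which is the behaviour this PR keeps.
__device__ float exp_accurate(float x) {
  return static_cast<float>(exp(static_cast<double>(x)));  // double-precision exp
}
__device__ float exp_fast(float x) {
  return expf(x);  // single-precision exp
}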
#pragma unroll
for (size_t k = cur_vocab_parallel_logits_ptr_begin;
     k < (cur_vocab_parallel_logits_ptr_begin + begin_offset); k++) {
  float val = float(vocab_parallel_logits_ptr[k]);
Suggested change:
float val = static_cast<float>(vocab_parallel_logits_ptr[k]);
Yes, @ptrendx, thank you for the reminder. I have changed all float type conversions in the code to static_cast<float>.
for (size_t k = cur_vocab_parallel_logits_ptr_begin;
     k < (cur_vocab_parallel_logits_ptr_begin + begin_offset); k++) {
  float val = float(vocab_parallel_logits_ptr[k]);
  row_item = exp(val - cur_row_max);
Suggested change:
row_item = expf(val - cur_row_max);
Done
float val = float(vocab_parallel_logits_ptr[k]);
row_item = exp(val - cur_row_max);
Same as above
Done
float row_sum = BlockReduceT(temp_storage).Sum(cur_thread_exp_sum);

if (threadIdx.x == 0) {
Wouldn't it be better to use multiple threads to do it (and then use shfl to do reduction inside a warp)?
Yes, good idea @ptrendx. Using multiple threads should accelerate the computation here and avoid warp divergence. I think this can be treated as an independent modification once you have no issues with the other parts.
I should add that, as mentioned before, ndim is currently known to be a multiple of 8, so in general the computation will not enter this loop at all.
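As a rough sketch of the warp-level variant being suggested (begin_idx and end_idx are placeholders for the leftover element range, not names from the PR):

// Let the first warp stride over the leftover head/tail elements and reduce
// with shuffles, instead of a serial loop on threadIdx.x == 0 only.
if (threadIdx.x < 32) {
  float partial = 0.f;
  for (size_t k = begin_idx + threadIdx.x; k < end_idx; k += 32) {
    partial += exp(static_cast<float>(vocab_parallel_logits_ptr[k]) - cur_row_max);
  }
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);  // warp tree reduction
  }
  if (threadIdx.x == 0) cur_thread_exp_sum += partial;  // lane 0 holds the warp sum
}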
float row_item = exp(data_fp32 - cur_row_max);
row_item = row_item / cur_row_exp_sum;  // compute softmax
row_item = log(row_item);  // after softmax, compute log
Suggested change:
float row_item = expf(data_fp32 - cur_row_max);
row_item = row_item / cur_row_exp_sum;  // compute softmax
row_item = logf(row_item);  // after softmax, compute log
Done
}

template <typename dtype, int BlockSize>
void __global__ CrossEntropyFwdMeanLogKernel(float* mean_log_probs_ptr,
I have basically the same requests as for the previous kernel.
Also, would it be possible to merge those two kernels, for example via templating, in order to not duplicate code? They seem mostly the same.
Hi @ptrendx, from the perspective of the overall computational logic, these are all elementwise operations, and apart from the elementwise op itself the two kernels are indeed very similar.
However, these kernels are used in Megatron-LM to replace some torch ops, and they are aimed at ordinary Python users. To keep the function naming clear, I suggest implementing the kernels separately. What do you think?
Right, on the Python side those would be exposed as separate functions like you have in this screenshot, no argument here. But internally the kernel code itself could be shared. Basically a structure like this:
template <typename func>
__global__ void kernel(...) {
  ...
  // use func to do the per-element computation
  ...
}

at::Tensor cross_entropy_fwd_sum_exp( ... ) {
  ...
  kernel<sum_exp><<<grid, block>>>( ... );
  ...
}

at::Tensor cross_entropy_fwd_mean_log( ... ) {
  ...
  kernel<mean_log><<<grid, block>>>( ... );
  ...
}
row = exp(row - logits_max);
row /= sum_exp_logits;
if (i == (size_t)masked_target_1d) {  // i == masked_target_1d
Suggested change:
row = expf(row - logits_max);
row /= sum_exp_logits;
if (i == static_cast<size_t>(masked_target_1d)) {
Done
if (i == (size_t)masked_target_1d) {  // i == masked_target_1d
  row = row - softmax_update;
}
if (label_smoothing > 0) {
Why not just use smoothing here - you waste a register by holding onto this value.
I'm sorry, I didn't understand what you meant. Could you explain it in more detail? Thanks @ptrendx
You pass both label_smoothing and smoothing to this function, whereas the smoothing value seems to be directly tied to the label_smoothing value via the formula
smoothing = label_smoothing * vocab_size / (vocab_size - 1);
The condition label_smoothing > 0 is equivalent to smoothing > 0, right? So you could save a register by forgetting the label_smoothing value after the initial setting of smoothing (or maybe even just pass the value of smoothing to the kernel in the first place, so that this computation is not performed by all the threads) and just use smoothing instead:
if (smoothing > 0) {
  row -= smoothing * average_grad;
}
(Considering that smoothing can apparently only be either 0 or > 0, you could even remove the conditional completely and always do row -= smoothing * average_grad;.)
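A minimal sketch of that simplification, under the assumption that the fold happens once on the host side before the launch (the real wrapper and signatures in the PR may differ):

// Host side (e.g. in the C++ wrapper): fold label_smoothing into smoothing once,
// and pass only smoothing to the kernel.
const float smoothing = label_smoothing * vocab_size / static_cast<float>(vocab_size - 1);

// Device side: label_smoothing never reaches the kernel, so no register is spent on it.
if (smoothing > 0.f) {
  row -= smoothing * average_grad;
}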
Description
As we can see, there are currently two cross-entropy implementations in Megatron-LM: the most basic one, "original", and another called "fused". The "fused" implementation is faster than "original", but it consumes the most GPU memory. The two new implementations added here, "Triton" and "CUDA kernel", train faster and save more GPU memory than the existing Megatron-LM implementations.
We also compared the convergence of the four implementations and found that the loss curves were essentially identical, indicating that there is no problem with calculation accuracy.

Notes:
We trained on 8 H100 GPUs for 4 hours to run the performance and accuracy tests above.
Wandb testing link: https://wandb.ai/megatron-core-moe-dev/binc-efficient-cross-entropy-no-moe?nw=nwuserbinc521
Wandb report link: https://api.wandb.ai/links/megatron-core-moe-dev/m1qovycf
Megatron-LLM MR: https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/1846
Type of change
Changes
Please list the changes introduced in this PR: