[pyTorch] Infrastructure for C++ QuantizedTensor #1251

Draft · wants to merge 6 commits into base: main
Changes from 4 commits
139 changes: 139 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -9,9 +9,11 @@
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
from ..tensor import Float8Tensor


__all__ = [
"general_gemm",
"gemm",
"fp8_gemm",
"grouped_gemm",
@@ -25,6 +27,143 @@ def _empty_tensor() -> torch.Tensor:
return torch.Tensor()


def general_gemm(
A: Union[torch.Tensor, Float8Tensor],
B: Union[torch.Tensor, Float8Tensor],
workspace: torch.Tensor,
gelu: bool = False,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
extra_output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@timmoon10 (Collaborator) commented on Oct 16, 2024:

While we're reworking this API, perhaps we should call it matmul since it's less ambiguous (e.g. with column-major/row-major order). We should also keep the core API simple like torch.matmul and np.matmul, and leave our non-standard options as kwargs:

def matmul(
    A: torch.Tensor,  # maybe QuantizedTensor
    B: torch.Tensor,  # maybe QuantizedTensor
    /,
    out: Optional[torch.Tensor] = None,  # maybe QuantizedTensor
    *,
    transa: bool = False,
    transb: bool = False,
    out_dtype: Optional[tex.DType] = None,
    accumulate_out: bool = False,  # alternatively: alpha and beta
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,  # more general than gelu
    workspace: torch.Tensor,  # maybe allocate in C++ if not provided
    use_split_accumulator: bool = False,  # maybe hide within cublas_options kwarg?
    userbuffers_options: Optional[dict] = None,  # minimize impact of unstable UB API
) -> torch.Tensor:  # maybe QuantizedTensor
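
For illustration, a call under this proposed signature might look like the sketch below. It is purely hypothetical: matmul does not exist in the codebase, the workspace size is arbitrary, and the shapes simply follow the transa=True / transb=False convention that general_gemm hard-codes further down.

# Hypothetical call site for the proposed matmul API (sketch only, nothing here is merged).
import torch
import transformer_engine_torch as tex

A = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)  # (N, K)
B = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)   # (M, K)
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # size is illustrative

# Positional-only inputs, keyword-only options; with transa=True and transb=False the
# result has shape (B.size(0), A.size(0)) = (64, 128), matching the D allocation in
# te_gemm2_helper below.
out = matmul(
    A, B,
    transa=True,
    out_dtype=tex.DType.kBFloat16,
    workspace=workspace,
)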

"""GEMM supporting fp8 inputs."""

empty_tensor = _empty_tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
raise ValueError("FP8 output not supported")
# assert_dim_for_fp8_exec(A)
# assert_dim_for_fp8_exec(B)

if out is not None:
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
# if gelu:
# gelu_input = torch.empty_like(out, dtype=bias_dtype)
# else:
# gelu_input = empty_tensor
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[A.dtype] if D_dtype is None else D_dtype

args = (
A,
True, # transa
B,
False, # transb
out,
None, # if out_index is None else fp8_meta_tensor.scale[out_index],
out_dtype,
None, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
bias,
bias_dtype,
gelu,
False, # grad
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)
fn = tex.te_gemm2
if ub_algo is not None:
raise ValueError("Not implemented yet!")
assert ub is not None, "ub object is None!"
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(
args
+ (
1,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(
args
+ (
0,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
fn = ub.atomic_gemm_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), "SPLIT_PIPELINED_RS requires extra output tensor"
args = tuple(
args
+ (
True,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "SPLIT_PIPELINED_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor"
args = tuple(
args
+ (
True,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
fn = ub.atomic_gemm_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "ATOMIC_GEMM_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
gelu_input = empty_tensor
else:
out, gelu_input = fn(*args)

return out, gelu_input


def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
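
For reference, the plain-tensor path through the new general_gemm can be exercised roughly as follows. This is a sketch under assumptions not shown in this diff: the import path mirrors the existing cpp_extensions exports, the workspace size is arbitrary, and the function returns the (out, gelu_input) pair exactly as written above, the -> torch.Tensor annotation notwithstanding.

# Sketch of calling the new general_gemm with plain torch.Tensor inputs
# (assumed import path; workspace size is illustrative).
import torch
from transformer_engine.pytorch.cpp_extensions import general_gemm

A = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)  # (N, K)
B = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)   # (M, K)
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# transa/transb are hard-coded to True/False inside general_gemm, so the output has
# shape (B.size(0), A.size(0)) = (64, 128); gelu_input is only meaningful when gelu=True.
out, gelu_input = general_gemm(A, B, workspace)
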
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -83,6 +83,14 @@ enum FP8BwdTensors {
GRAD_INPUT3 = 5
};

class Float8Tensor {
public:
at::Tensor data;
std::optional<at::Tensor> transpose = std::nullopt;
at::Tensor scale_inv;
DType dtype;
};

} // namespace transformer_engine

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
16 changes: 15 additions & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -8,7 +8,6 @@
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

#include "common.h"
#include "common/common.h"

/***************************************************************************************************
* Permutation
@@ -138,6 +137,21 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
* GEMM
**************************************************************************************************/

using MaybeTensor = std::optional<at::Tensor>;

std::vector<at::Tensor> te_gemm2(transformer_engine::Float8Tensor A, bool transa,
transformer_engine::Float8Tensor B, bool transb, MaybeTensor D,
MaybeTensor D_scale, transformer_engine::DType D_type,
MaybeTensor D_amax, MaybeTensor bias,
transformer_engine::DType bias_type, bool gelu, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator);
std::vector<at::Tensor> te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb,
MaybeTensor D, MaybeTensor D_scale,
transformer_engine::DType D_type, MaybeTensor D_amax,
MaybeTensor bias, transformer_engine::DType bias_type, bool gelu,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator);
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
bool transa, at::Tensor B, at::Tensor B_scale_inverse,
transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale,
1 change: 1 addition & 0 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/

#include "common/common.h"
#include "extensions.h"

constexpr int block_size = 512;
110 changes: 110 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -4,8 +4,118 @@
* See LICENSE for license information.
************************************************************************/

#include <optional>

#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "pytorch/csrc/common.h"
#include "transformer_engine/transformer_engine.h"

namespace {

void* get_data_ptr(MaybeTensor tensor) {
if (tensor.has_value()) return tensor->data_ptr();
return nullptr;
}

size_t get_size(MaybeTensor tensor, int dim) {
if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
return 0;
}

} // namespace

std::vector<at::Tensor> te_gemm2_helper(
at::Tensor A, transformer_engine::DType A_dtype, MaybeTensor A_scale_inv, bool transa,
at::Tensor B, transformer_engine::DType B_dtype, MaybeTensor B_scale_inv, bool transb,
MaybeTensor D, MaybeTensor D_scale, transformer_engine::DType D_type, MaybeTensor D_amax,
MaybeTensor bias, transformer_engine::DType bias_type, bool gelu, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator) {
using namespace transformer_engine;
if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) {
at::Tensor out;
if (D.has_value() && D->data_ptr() != nullptr && !accumulate) {
D->zero_();
out = *D;
} else {
out = at::Tensor(); // TODO: Handle D without a value
}
return {out, at::Tensor()};
}

// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs

const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

A = A.contiguous();
B = B.contiguous();

if (!D.has_value()) {
auto type = GetATenDType(D_type);
auto opts = at::TensorOptions().dtype(type).device(A.options().device());
    D = at::empty({B.size(0), A.size(0)}, opts);  // assign the optional itself; D is empty here
}

auto te_A = makeTransformerEngineTensor(
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_dtype,
nullptr, nullptr, get_data_ptr(A_scale_inv));
auto te_B = makeTransformerEngineTensor(
B.data_ptr(), {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))}, B_dtype,
nullptr, nullptr, get_data_ptr(B_scale_inv));
auto te_D = makeTransformerEngineTensor(
D->data_ptr(), {static_cast<size_t>(D->size(0)), static_cast<size_t>(D->size(1))}, D_type,
get_data_ptr(D_amax), get_data_ptr(D_scale), nullptr);
auto te_bias = makeTransformerEngineTensor(get_data_ptr(bias), {get_size(bias, 0)}, bias_type);

at::Tensor pre_gelu_out;
if (gelu) {
auto dtype = GetATenDType(bias_type);
auto opts = A.options().dtype(dtype);
pre_gelu_out = at::empty_like(*D, opts);
}
const auto gelu_shape = gelu ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))}
: std::vector<size_t>{0};
auto te_pre_gelu_out = makeTransformerEngineTensor(
pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
auto te_workspace =
makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);

nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(),
transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator,
num_math_sms, at::cuda::getCurrentCUDAStream());

return {*D, pre_gelu_out};
}

std::vector<at::Tensor> te_gemm2(transformer_engine::Float8Tensor A, bool transa,
transformer_engine::Float8Tensor B, bool transb, MaybeTensor D,
MaybeTensor D_scale, transformer_engine::DType D_type,
MaybeTensor D_amax, MaybeTensor bias,
transformer_engine::DType bias_type, bool gelu, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator) {
return te_gemm2_helper(A.data, A.dtype, A.scale_inv, transa, B.data, B.dtype, B.scale_inv, transb,
D, D_scale, D_type, D_amax, bias, bias_type, gelu, grad, workspace,
workspaceSize, accumulate, use_split_accumulator);
}

std::vector<at::Tensor> te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb,
MaybeTensor D, MaybeTensor D_scale,
transformer_engine::DType D_type, MaybeTensor D_amax,
MaybeTensor bias, transformer_engine::DType bias_type, bool gelu,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator) {
transformer_engine::DType A_dtype = GetTransformerEngineDType(A.scalar_type());
transformer_engine::DType B_dtype = GetTransformerEngineDType(B.scalar_type());
return te_gemm2_helper(A, A_dtype, std::nullopt, transa, B, B_dtype, std::nullopt, transb, D,
D_scale, D_type, D_amax, bias, bias_type, gelu, grad, workspace,
workspaceSize, accumulate, use_split_accumulator);
}

void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
bool transa, at::Tensor B, at::Tensor B_scale_inverse,