
Commit

Initial grouped gemm integration. Builds, but untested.
tgale96 committed Sep 20, 2023
1 parent 417692a commit 96bc988
Showing 7 changed files with 292 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*~
build
*.egg-info
dist
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass
214 changes: 214 additions & 0 deletions csrc/grouped_gemm.cu
@@ -0,0 +1,214 @@
#include "grouped_gemm.h"

#include <c10/util/BFloat16.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

namespace grouped_gemm {

#define CUDA_CALL(code)                                \
  do {                                                 \
    cudaError_t status = code;                         \
    std::string err = cudaGetErrorString(status);      \
    TORCH_CHECK(status == cudaSuccess, err);           \
  } while (0)

using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
  // Non-transposed A operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  ::cutlass::ComplexTransform::kNone,
  8,
  // Non-transposed B operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  ::cutlass::ComplexTransform::kNone,
  8,
  // C operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  float,
  ::cutlass::arch::OpClassTensorOp,
  // TODO(tgale): Update this to support SM90.
  ::cutlass::arch::Sm80,
  ::cutlass::gemm::GemmShape<128, 128, 32>,
  ::cutlass::gemm::GemmShape<64, 64, 32>,
  ::cutlass::gemm::GemmShape<16, 8, 16>,
  ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
  // This parameter is passed in at present to match the APIs of other kernels. The parameter
  // is unused within the kernel.
  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
  // TODO(tgale): Experiment with GroupScheduleMode.
  4>::GemmKernel;
using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernelNN>;

std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) {
  const size_t num_experts = batch_sizes.size(0);
  const size_t k = b.size(1), n = b.size(2);
  std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
  for (int i = 0; i < num_experts; ++i) {
    problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr<int64_t>()[i], n, k);
  }
  return problem_sizes;
}

template <typename T>
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
  size_t bytes = x.size() * sizeof(T);
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
  torch::Tensor out = torch::empty(bytes, options);

  CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
                            x.data(), bytes,
                            cudaMemcpyHostToDevice,
                            c10::cuda::getCurrentCUDAStream()));
  return out;
}

template <typename Gemm>
typename Gemm::Arguments MakeArguments(torch::Tensor a,
                                       torch::Tensor b,
                                       torch::Tensor c,
                                       torch::Tensor batch_sizes) {
  auto problem_sizes_host = MakeProblemSizes(b, batch_sizes);

  // Calculate the number of threadblocks to use and validate the result.
  int64_t num_experts = problem_sizes_host.size();

  // NOTE: This is borrowed from FasterTransformer.
  int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
  if (!threadblock_count) {
    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
  }

  // Create the host arrays of leading dimension data and pointer data.
  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  std::vector<int64_t> lda_host(num_experts), offsets_a(num_experts);
  std::vector<int64_t> ldb_host(num_experts), offsets_b(num_experts);
  std::vector<int64_t> ldc_host(num_experts), offsets_c(num_experts);
  int64_t elements_a = 0, elements_b = 0, elements_c = 0;

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;
  std::vector<ElementA *> ptr_a_host(num_experts);
  std::vector<ElementB *> ptr_b_host(num_experts);
  std::vector<ElementC *> ptr_c_host(num_experts);

  for (int i = 0; i < num_experts; ++i) {
    auto problem = problem_sizes_host[i];
    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);

    offsets_a[i] = elements_a;
    offsets_b[i] = elements_b;
    offsets_c[i] = elements_c;

    ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i];
    ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i];
    ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i];

    elements_a += problem.m() * problem.k();
    elements_b += problem.k() * problem.n();
    elements_c += problem.m() * problem.n();
  }

  // Copy the problem sizes, pointers and leading dimension data to the device.
  torch::Tensor lda = CopyToDevice(lda_host, a.device());
  torch::Tensor ldb = CopyToDevice(ldb_host, a.device());
  torch::Tensor ldc = CopyToDevice(ldc_host, a.device());
  torch::Tensor ptr_a = CopyToDevice(ptr_a_host, a.device());
  torch::Tensor ptr_b = CopyToDevice(ptr_b_host, a.device());
  torch::Tensor ptr_c = CopyToDevice(ptr_c_host, a.device());
  torch::Tensor problem_sizes = CopyToDevice(problem_sizes_host, a.device());

  typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)problem_sizes.data_ptr(),
                                     (int)num_experts,
                                     (int)threadblock_count,
                                     epilogue_op,
                                     (ElementA**)ptr_a.data_ptr(),
                                     (ElementB**)ptr_b.data_ptr(),
                                     (ElementC**)ptr_c.data_ptr(),
                                     (ElementC**)ptr_c.data_ptr(),
                                     /*lda=*/(int64_t*)lda.data_ptr(),
                                     /*ldb=*/(int64_t*)ldb.data_ptr(),
                                     /*ldc=*/(int64_t*)ldc.data_ptr(),
                                     /*ldd=*/(int64_t*)ldc.data_ptr(),
                                     (cutlass::gemm::GemmCoord*)problem_sizes_host.data());
  return arguments;
}


// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed sized batches.
//
// TODO(tgale): Validate alignment is true for every batch element.
torch::Tensor GroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor batch_sizes) {
  // We expect a CUDA tensor with two dimensions and shape
  // (tokens, hidden_in) for 'a'.
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.ndimension() == 2);
  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

  // We expect a CUDA tensor with three dimensions and shape
  // (num_experts, hidden_in, hidden_out) for 'b'.
  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(b.ndimension() == 3);
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);

  // We expect 'batch_sizes' to be a 1D int64 tensor on the CPU.
  TORCH_CHECK(batch_sizes.is_cpu());
  TORCH_CHECK(batch_sizes.ndimension() == 1);
  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);

  // Validate that the contraction dimensions match.
  int64_t tokens = a.size(0), hidden_in = a.size(1);
  int64_t num_experts = b.size(0), hidden_out = b.size(2);
  TORCH_CHECK(hidden_in == b.size(1));

  // Validate that we have one size per expert.
  TORCH_CHECK(batch_sizes.size(0) == num_experts);

  // Allocate the output.
  auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(a.device());
  torch::Tensor c = torch::empty({tokens, hidden_out}, options);

  // TODO(tgale): Support fused transposition.
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(b.is_contiguous());

  using Gemm = GemmGroupedNN;
  Gemm gemm;

  auto arguments = MakeArguments<Gemm>(a, b, c, batch_sizes);
  int64_t workspace_size = gemm.get_workspace_size(arguments);
  options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
  torch::Tensor workspace = torch::empty(workspace_size, options);

  // Initialize the kernel.
  if (gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
  }

  // Execute the kernel in the current stream.
  if (gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
  }

  // Return the output tensor.
  return c;
}

} // namespace grouped_gemm
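
Semantically, grouped_gemm.cu splits the rows of 'a' (shape (tokens, hidden_in)) into per-expert groups according to batch_sizes and multiplies each group by the corresponding slice of 'b' (shape (num_experts, hidden_in, hidden_out)), writing the results back-to-back into a (tokens, hidden_out) output. The sketch below is a minimal pure-PyTorch reference for those semantics, not part of this commit; grouped_gemm_reference is a hypothetical helper name, and since the commit is untested it should be read as the intended behavior rather than a verified equivalence.

import torch

def grouped_gemm_reference(a, b, batch_sizes):
    # a: (tokens, hidden_in) bf16 CUDA tensor; rows grouped by expert.
    # b: (num_experts, hidden_in, hidden_out) bf16 CUDA tensor.
    # batch_sizes: (num_experts,) int64 CPU tensor; sum(batch_sizes) == tokens.
    out = torch.empty(a.size(0), b.size(2), dtype=a.dtype, device=a.device)
    start = 0
    for i, m in enumerate(batch_sizes.tolist()):
        # Problem i is an (m x hidden_in) @ (hidden_in x hidden_out) GEMM,
        # i.e. GemmCoord(batch_sizes[i], hidden_out, hidden_in) in MakeProblemSizes.
        out[start:start + m] = a[start:start + m] @ b[i]
        start += m
    return out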
7 changes: 7 additions & 0 deletions csrc/grouped_gemm.h
@@ -0,0 +1,7 @@
#include <torch/extension.h>

namespace grouped_gemm {

torch::Tensor GroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor batch_sizes);

} // namespace grouped_gemm
11 changes: 11 additions & 0 deletions csrc/ops.cu
@@ -0,0 +1,11 @@
#include "grouped_gemm.h"

#include <torch/extension.h>

namespace grouped_gemm {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grouped_gemm", &GroupedGemm, "Grouped GEMM.");
}

} // namespace grouped_gemm
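
The binding exposes GroupedGemm as grouped_gemm on the extension, and setup.py names the extension grouped_gemm_backend, so the intended calling convention looks roughly like the sketch below. This assumes the extension has been built and installed; the example shapes are illustrative, and since the commit message says "untested" this is the expected interface rather than a verified run.

import torch
import grouped_gemm_backend  # extension name set in setup.py

num_experts, hidden_in, hidden_out = 4, 1024, 2048
batch_sizes = torch.tensor([128, 64, 256, 64], dtype=torch.int64)  # int64, on CPU
tokens = int(batch_sizes.sum())

a = torch.randn(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)

# Contiguous bf16 CUDA inputs plus int64 CPU batch_sizes satisfy the TORCH_CHECKs above.
c = grouped_gemm_backend.grouped_gemm(a, b, batch_sizes)
assert c.shape == (tokens, hidden_out)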
52 changes: 52 additions & 0 deletions setup.py
@@ -0,0 +1,52 @@
import os
from pathlib import Path
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


if not torch.cuda.is_available():
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"

cwd = Path(os.path.dirname(os.path.abspath(__file__)))
_dc = torch.cuda.get_device_capability()
_dc = f"{_dc[0]}{_dc[1]}"

ext_modules = [
    CUDAExtension(
        "grouped_gemm_backend",
        ["csrc/ops.cu", "csrc/grouped_gemm.cu"],
        include_dirs=[
            f"{cwd}/third_party/cutlass/include/"
        ],
        extra_compile_args={
            "cxx": [
                "-fopenmp", "-fPIC", "-Wno-strict-aliasing"
            ],
            "nvcc": [
                f"--generate-code=arch=compute_{_dc},code=sm_{_dc}",
                # NOTE: CUTLASS requires c++17.
                "-std=c++17",
            ],
        },
    )
]

setup(
    name="grouped_gemm",
    version="0.0.1",
    author="Trevor Gale",
    author_email="[email protected]",
    description="Grouped GEMM",
    url="https://github.com/tgale96/grouped_gemm",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    packages=find_packages(),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    install_requires=["absl-py", "numpy", "torch"],
)
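
setup.py bakes the local GPU's compute capability into the nvcc --generate-code flag, and the kernel itself is instantiated for cutlass::arch::Sm80, so the bf16 grouped GEMM effectively targets SM80-class hardware and newer. The snippet below is a hypothetical pre-flight check, not part of this commit, that warns when the visible device is older than SM80.

import torch

# Hypothetical check: the grouped kernel is compiled for cutlass::arch::Sm80.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (8, 0):
        print(f"Warning: device is SM{major}{minor}; the bf16 grouped GEMM "
              "targets SM80+ and may fail to build or run here.")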
1 change: 1 addition & 0 deletions third_party/cutlass
Submodule cutlass added at 8783c4
