
Commit 35e9c12

varun-sundar-rabindranath and Varun Sundar Rabindranath authored
[Kernel] Tuned int8 Cutlass Kernels for SM75 (T4) (vllm-project#6996)
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 93548eb commit 35e9c12

File tree

3 files changed: +135, -12 lines

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 8 additions & 1 deletion
@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
 
     timers = []
-    # pytorch impl
+    # pytorch impl - bfloat16
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                  torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
+    # pytorch impl - float16
+    timers.append(
+        bench_fn(a.to(dtype=torch.float16, device="cuda"),
+                 b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
+                 torch.float16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_fp16_fp16_fp16_matmul-no-scales"))
+
     # cutlass impl
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,

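Note: pytorch_mm_impl is the existing plain-matmul baseline shared by the bf16 row and the new fp16 row; the "-no-scales" suffix in the labels indicates the scales are not applied. As an illustration only (the helper name and exact signature below are assumptions, not code from this commit), such a baseline could look like:

import torch

# Hypothetical sketch of a scale-free PyTorch baseline. The signature mirrors
# the bench_fn arguments used above (a, b, scale_a, scale_b, out_dtype); the
# scales are accepted but ignored, matching the "-no-scales" label.
def pytorch_mm_baseline(a: torch.Tensor, b: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        out_dtype: torch.dtype) -> torch.Tensor:
    return torch.mm(a, b)

if __name__ == "__main__":
    m, k, n = 128, 4096, 4096
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)
    scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    out = pytorch_mm_baseline(a, b, scale, scale, torch.float16)
    print(out.shape, out.dtype)
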
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 4 additions & 11 deletions
@@ -3,6 +3,7 @@
 #include "cutlass/cutlass.h"
 
 #include "scaled_mm_c2x.cuh"
+#include "scaled_mm_c2x_sm75_dispatch.cuh"
 #include "scaled_mm_c2x_sm80_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
@@ -20,21 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
-
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_caller<
-        vllm::cutlass_2x_gemm<cutlass::arch::Sm75, vllm::enable_sm75_to_sm80,
-                              int8_t, cutlass::bfloat16_t, Epilogue, TileShape,
-                              WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
+                                            Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_caller<vllm::cutlass_2x_gemm<
-        cutlass::arch::Sm75, vllm::enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+#pragma once
+
+#include "scaled_mm_c2x.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM75 based on the Gemm
+ * shape.
+ */
+
+namespace vllm {
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_default {
+  // This config is used in 2 cases,
+  // - M in (256, inf]
+  // - M in (64, 128]
+  // Shared memory required by this Gemm 32768
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M256 {
+  // M in (128, 256]
+  // Shared memory required by this Gemm 65536
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M64 {
+  // M in (32, 64]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M32 {
+  // M in [1, 32]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
+                                       torch::Tensor const& a,
+                                       torch::Tensor const& b,
+                                       EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass2xGemmDefault =
+      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM256 =
+      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
+  using Cutlass2xGemmM64 =
+      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM32 =
+      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
+
+  // Due to shared memory requirements, some Gemms may fail to run on some
+  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
+  // in such cases.
+  // sm75_config_default has the least shared-memory requirements.
+  using FallbackGemm = Cutlass2xGemmDefault;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+  if (mp2 <= 32) {
+    // M in [1, 32]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 64) {
+    // M in (32, 64]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // M in (64, 128]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 256) {
+    // M in (128, 256]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // M in (256, inf)
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
+}  // namespace vllm
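For a quick sanity check of the dispatch boundaries, the following standalone Python sketch (illustrative only, not part of this commit) mirrors the power-of-two bucketing above and prints which SM75 configuration a given M would select:

# Illustrative sketch: reproduces the thresholds and tile shapes from
# cutlass_gemm_sm75_dispatch to show the M -> config mapping.
def next_pow_2(x: int) -> int:
    # Smallest power of two >= x (matches the rounding used in the dispatch).
    return 1 if x <= 1 else 2 ** (x - 1).bit_length()

def sm75_config_for(m: int) -> str:
    mp2 = max(32, next_pow_2(m))
    if mp2 <= 32:
        return "sm75_config_M32      TileShape<32,128,64>,   smem 49152"
    if mp2 <= 64:
        return "sm75_config_M64      TileShape<64,128,128>,  smem 49152"
    if mp2 <= 128:
        return "sm75_config_default  TileShape<128,128,64>,  smem 32768"
    if mp2 <= 256:
        return "sm75_config_M256     TileShape<128,128,128>, smem 65536"
    return "sm75_config_default  TileShape<128,128,64>,  smem 32768"

for m in (1, 16, 32, 33, 64, 96, 128, 200, 256, 512, 4096):
    print(f"M={m:5d} -> {sm75_config_for(m)}")
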
