59 changes: 59 additions & 0 deletions benchmarks/ascend/bench_k_grouped_gemm.py
@@ -0,0 +1,59 @@
import torch
import triton

from dlblas.utils.device_utils import infer_device
from dlblas.kernels.ascend.grouped_gemm import k_grouped_gemm
from tests.kernels.ascend.test_grouped_matmul import generate_random_list, k_grouped_matmul_torch

if __name__ == '__main__':
    groups = 8
    z = groups
    DEV = infer_device()
    dtype_ = torch.bfloat16
    batch_sizes = torch.Tensor(generate_random_list(groups, groups * 2560)).to(DEV).to(torch.int64).abs()
    K = batch_sizes.sum().item()
    for (M, N) in ((4096, 4096), (512, 512), (768 * 2, 2048), (2048, 768), (1536 * 2, 4096)):
        a = torch.randn(K, M, dtype=dtype_, device=DEV)
        b = torch.randn(K, N, dtype=dtype_, device=DEV)
        # correctness check against the torch reference before timing
        golden = k_grouped_matmul_torch(a, b, batch_sizes.cpu())
        result = k_grouped_gemm(a, b, batch_sizes)
        mask = golden.abs() < 1.0
        tmpatol = tmprtol = 2 ** -6
        torch.testing.assert_close(result[mask], golden[mask], atol=tmpatol, rtol=0)
        torch.testing.assert_close(result[~mask], golden[~mask], atol=0, rtol=tmprtol)

        configs = []
        configs.append(
            triton.testing.Benchmark(
                x_names=['cnt'],  # Argument names to use as an x-axis for the plot
                # x_vals=[128 * i for i in range(10, 15)],  # Different possible values for `x_name`
                x_vals=[1],  # NOTE: the tuning framework is specialized to a single shape
                line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
                line_vals=['triton_gmm', 'torch'],  # Possible values for `line_arg`
                line_names=['Triton_gmm', 'Torch'],  # Label names for the lines
                styles=[('green', '-'), ('blue', '-')],  # Line styles
                ylabel='TFLOPS',  # Label name for the y-axis
                plot_name='k_grouped_matmul-performance-' +
                f'bf16-[Batch={z} M={M} N={N} k={K}]',  # Name for the plot, also used as the file name when saving
                args={},
            ))

        @triton.testing.perf_report(configs)
        def benchmark(cnt, provider):
            warmup = 500
            rep = 500
            quantiles = [0.5, 0.2, 0.8]
            if provider == 'torch':
                ms, min_ms, max_ms = triton.testing.do_bench(lambda: k_grouped_matmul_torch(a, b, batch_sizes),
                                                             quantiles=quantiles,
                                                             warmup=warmup,
                                                             rep=rep)
            if provider == 'triton_gmm':
                ms, min_ms, max_ms = triton.testing.do_bench(lambda: k_grouped_gemm(a, b, batch_sizes),
                                                             quantiles=quantiles,
                                                             warmup=warmup,
                                                             rep=rep)

            return ms, max_ms, min_ms

        benchmark.run(show_plots=False, print_data=True)
        print("run matmul success")
98 changes: 50 additions & 48 deletions benchmarks/ascend/bench_m_grouped_gemm.py
@@ -1,60 +1,62 @@
import torch
import triton
from dlblas.utils.device_utils import infer_device
-from dlblas.kernels.ascend.m_grouped_gemm import m_grouped_gemm
-from tests.kernels.ascend.test_m_grouped_matmul import generate_random_list, torch_grouped_matmul
+from dlblas.kernels.ascend.grouped_gemm import m_grouped_gemm
+from tests.kernels.ascend.test_grouped_matmul import generate_random_list, m_grouped_matmul_torch

if __name__=='__main__':
-    groups = 128
+    groups = 8
    z = groups
-    trans_b = False; print(f"{trans_b = }")
    device = infer_device()
    batch_sizes = torch.Tensor(generate_random_list(groups, groups*2560)).to(device).to(torch.int64)
    M = batch_sizes.sum().item()
-    for (n, k) in ((4096, 4096), (512, 512), (768*2, 2048), (2048, 768), (1536*2, 4096)): # (4096, 1536)
+    for (n, k) in ((4096, 4096), (512, 512), (768*2, 2048), (2048, 768), (1536*2, 4096)):
        a = torch.randn(M, k, dtype = torch.bfloat16, device = device).view(-1, k)
-        b = torch.randn(z, n, k, dtype = torch.bfloat16, device = device) if trans_b else torch.randn(z, k, n, dtype = torch.bfloat16, device = device)
-        print(f"M={M}, z={z}, k={k}, n={n}")
-        golden = torch_grouped_matmul(a, b, batch_sizes, trans_b)
-        result = m_grouped_gemm(a, b, batch_sizes, trans_b)
-        mask = golden.abs() < 1.0
-        tmpatol = tmprtol = 2 ** -6
-        torch.testing.assert_close(result[mask], golden[mask], atol = tmpatol, rtol = 0)
-        torch.testing.assert_close(result[~mask], golden[~mask], atol = 0, rtol = tmprtol)
-        configs = []
-        configs.append(
-            triton.testing.Benchmark(
-                x_names=['cnt'], # Argument names to use as an x-axis for the plot
-                # x_vals=[128 * i for i in range(10, 15)], # Different possible values for `x_name`
-                x_vals=[1], # NOTE: the tunning framework specialized to one shape
-                line_arg='provider', # Argument name whose value corresponds to a different line in the plot
-                # Possible values for `line_arg`
-                # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
-                line_vals=['triton_gmm' , 'torch'] , # Label name for the lines
-                line_names=['Triton_gmm', 'Torch'] , # Line styles
-                styles=[('green', '-'), ('blue', '-')],
-                ylabel='TFLOPS', # Label name for the y-axis
-                plot_name='m_grouped_matmul-performance-' +
-                (f'bf16-[Batch={z} M={M} N={n} k={k}]'), # Name for the plot, used also as a file name for saving the plot.
-                args={},
-            ))
-        @triton.testing.perf_report(configs)
-        def benchmark(cnt, provider):
-            warmup = 500
-            rep = 500
-            quantiles = [0.5, 0.2, 0.8]
-            if provider == 'torch':
-                ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_grouped_matmul(a, b, batch_sizes, trans_b),
-                                                             quantiles=quantiles,
-                                                             warmup=warmup,
-                                                             rep=rep)
-            if provider == 'triton_gmm':
-                ms, min_ms, max_ms = triton.testing.do_bench(lambda: m_grouped_gemm(a, b, batch_sizes, trans_b),
-                                                             quantiles=quantiles,
-                                                             warmup=warmup,
-                                                             rep=rep)
+        for trans_b in (True, False):
+            if trans_b:
+                b = torch.randn(z, n, k, dtype = torch.bfloat16, device = device)
+            else:
+                b = torch.randn(z, k, n, dtype = torch.bfloat16, device = device)
+            golden = m_grouped_matmul_torch(a, b, batch_sizes, trans_b)
+            result = m_grouped_gemm(a, b, batch_sizes, trans_b)
+            mask = golden.abs() < 1.0
+            tmpatol = tmprtol = 2 ** -6
+            torch.testing.assert_close(result[mask], golden[mask], atol = tmpatol, rtol = 0)
+            torch.testing.assert_close(result[~mask], golden[~mask], atol = 0, rtol = tmprtol)
+            configs = []
+            configs.append(
+                triton.testing.Benchmark(
+                    x_names=['cnt'], # Argument names to use as an x-axis for the plot
+                    # x_vals=[128 * i for i in range(10, 15)], # Different possible values for `x_name`
+                    x_vals=[1], # NOTE: the tunning framework specialized to one shape
+                    line_arg='provider', # Argument name whose value corresponds to a different line in the plot
+                    # Possible values for `line_arg`
+                    # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
+                    line_vals=['triton_gmm' , 'torch'] , # Label name for the lines
+                    line_names=['Triton_gmm', 'Torch'] , # Line styles
+                    styles=[('green', '-'), ('blue', '-')],
+                    ylabel='TFLOPS', # Label name for the y-axis
+                    plot_name='m_grouped_matmul-performance-' +
+                    (f'bf16-[Batch={z} M={M} N={n} k={k} trans_b={trans_b}]'), # Name for the plot, used also as a file name for saving the plot.
+                    args={},
+                ))
+            @triton.testing.perf_report(configs)
+            def benchmark(cnt, provider):
+                warmup = 500
+                rep = 500
+                quantiles = [0.5, 0.2, 0.8]
+                if provider == 'torch':
+                    ms, min_ms, max_ms = triton.testing.do_bench(lambda: m_grouped_matmul_torch(a, b, batch_sizes, trans_b),
+                                                                 quantiles=quantiles,
+                                                                 warmup=warmup,
+                                                                 rep=rep)
+                if provider == 'triton_gmm':
+                    ms, min_ms, max_ms = triton.testing.do_bench(lambda: m_grouped_gemm(a, b, batch_sizes, trans_b),
+                                                                 quantiles=quantiles,
+                                                                 warmup=warmup,
+                                                                 rep=rep)

-            return ms, max_ms, min_ms
+            return ms, max_ms, min_ms

-        benchmark.run(show_plots=False, print_data=True)
-        print("run matmul success")
+            benchmark.run(show_plots=False, print_data=True)
+            print("run matmul success")
8 changes: 0 additions & 8 deletions dlblas/kernels/__init__.py
@@ -2,14 +2,6 @@
import importlib.util
import os

-# Manually import the specific submodules you need
-try:
-    # Import grouped_gemm subpackage
-    from .grouped_gemm.BF16 import *
-
-except ImportError as e:
-    print(f"Warning: Could not import submodules: {e}")
-
def import_all_modules_from_folder(folder_path):
    """
    dynamically import all Python modules under this folder
3 changes: 3 additions & 0 deletions dlblas/kernels/ascend/grouped_gemm/__init__.py
@@ -0,0 +1,3 @@
from .grouped_gemm import grouped_gemm_triton
from .m_grouped_gemm import m_grouped_gemm
from .k_grouped_gemm import k_grouped_gemm
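The subpackage `__init__` above gives the three entry points a single import path, the same one the new benchmarks use:

```python
from dlblas.kernels.ascend.grouped_gemm import (
    grouped_gemm_triton,  # autograd-aware wrapper for MoE layers
    m_grouped_gemm,       # GEMM grouped along M (rows/tokens), used for forward and dgrad
    k_grouped_gemm,       # GEMM grouped along K (tokens), used for wgrad
)
```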
35 changes: 35 additions & 0 deletions dlblas/kernels/ascend/grouped_gemm/grouped_gemm.py
@@ -0,0 +1,35 @@
import torch

from .m_grouped_gemm import m_grouped_gemm
from .k_grouped_gemm import k_grouped_gemm


class GroupedGemm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, w, tokens_per_expert):
        out = m_grouped_gemm(x, w, tokens_per_expert, trans_b=True)
        ctx.save_for_backward(x, w, tokens_per_expert)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, w, tokens_per_expert = ctx.saved_tensors
        dx = m_grouped_gemm(grad_output, w, tokens_per_expert, trans_b=False)
        dw = k_grouped_gemm(grad_output, x, tokens_per_expert)
        return dx, dw, None


def grouped_gemm_triton(x, w, tokens_per_expert):
    """Grouped matrix multiplication (GMM) for expert models.

    Args:
        x (Tensor): Input tensor of shape (batch_size, seq_len, din).
        w (Tensor): Weight tensor of shape (num_experts, dout, din).
        tokens_per_expert (Tensor): Number of tokens per expert.

    Returns:
        Tensor: Output tensor of shape (batch_size, seq_len, dout).
    """
    if x.shape[0] == 0:
        # keep x and w connected to the autograd graph even when there are no tokens
        return torch.matmul(x, w[0].T)
    return GroupedGemm.apply(x, w, tokens_per_expert)
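A short usage sketch of the wrapper above (shapes and expert sizes are illustrative, not taken from the PR). With `trans_b=True` the forward computes `out_e = x_e @ w[e].T` for the rows routed to expert `e`; the weight gradient is therefore `dw[e] = grad_e.T @ x_e`, a reduction over that expert's tokens, which is exactly the k-grouped GEMM called in `backward`:

```python
import torch
from dlblas.utils.device_utils import infer_device
from dlblas.kernels.ascend.grouped_gemm import grouped_gemm_triton

DEV = infer_device()
num_experts, din, dout = 4, 512, 1024                      # illustrative sizes only
tokens_per_expert = torch.tensor([1024, 2048, 512, 512], device=DEV)

x = torch.randn(int(tokens_per_expert.sum()), din, dtype=torch.bfloat16, device=DEV, requires_grad=True)
w = torch.randn(num_experts, dout, din, dtype=torch.bfloat16, device=DEV, requires_grad=True)

out = grouped_gemm_triton(x, w, tokens_per_expert)   # (num_tokens, dout)
out.sum().backward()                                 # dx via m_grouped_gemm, dw via k_grouped_gemm
assert x.grad.shape == x.shape and w.grad.shape == w.shape
```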
105 changes: 105 additions & 0 deletions dlblas/kernels/ascend/grouped_gemm/k_grouped_gemm.py
@@ -0,0 +1,105 @@

import torch
from torch import Tensor
import triton
import triton.language as tl

from dlblas.utils.op_helper import grouped_lanuch_diagonal
from dlblas.utils.device_utils import get_number_cores


def get_autotune_config():
    return [
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
    ]


@triton.autotune(configs=get_autotune_config(), key=['M', 'N'])
@triton.jit
def k_grouped_gemm_kernel(
    A,
    B,
    C,
    group_size_ptr,
    num_groups: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_TRESHHOLD: tl.constexpr,
):
    # Persistent kernel: each core walks a rolling block counter across all groups,
    # claiming every `total_cores`-th (M, N) tile.
    total_cores = tl.num_programs(axis=0)
    core_idx = tl.program_id(axis=0)
    last_count = 0
    group_start = 0
    group_end = 0
    num_block_m = tl.cdiv(M, BLOCK_M)
    num_block_n = tl.cdiv(N, BLOCK_N)
    blocks_per_group = num_block_m * num_block_n
    # group_size_k = tl.load(group_size_ptr + tl.arange(0, num_groups)).to(tl.int32)
    for group_idx in range(num_groups):
        # k = tl.extract_slice(group_size_k, [group_idx], [1], [1])
        tokens = tl.load(group_size_ptr + group_idx).to(tl.int32)
        group_end = group_start + tokens
        cur_count = last_count + blocks_per_group
        cur_block = core_idx if core_idx >= last_count else (core_idx + total_cores)
        while cur_block < cur_count:
            task_m_idx, task_n_idx = grouped_lanuch_diagonal(cur_block - last_count, num_block_m, num_block_n, BLOCK_TRESHHOLD)
            # matmul begin: this tile computes a BLOCK_M x BLOCK_N block of
            # C[group_idx] = A[group_start:group_end].T @ B[group_start:group_end]
            offs_am = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_bn = task_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
            offs_k = group_start + tl.arange(0, BLOCK_K)
            a_ptrs_base = A + offs_k[:, None] * M + offs_am[None, :]
            b_ptrs_base = B + offs_k[:, None] * N + offs_bn[None, :]
            msk_m = offs_am < M
            msk_n = offs_bn < N
            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            for kk in tl.range(0, tl.cdiv(tokens, BLOCK_K)):
                a_ptrs = a_ptrs_base + kk * BLOCK_K * M
                b_ptrs = b_ptrs_base + kk * BLOCK_K * N
                a = tl.load(a_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_m[None, :], other=0.0)
                aa = tl.trans(a)
                tl.compile_hint(aa, "dot_pad_only_k")
                b = tl.load(b_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_n[None, :], other=0.0)
                tl.compile_hint(b, "dot_pad_only_k")
                accumulator = tl.dot(aa, b, acc=accumulator)

            c = accumulator.to(C.dtype.element_ty)
            offs_cm = group_idx * M + task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_cn = task_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
            c_ptrs = C + offs_cm[:, None] * N + offs_cn[None, :]
            c_mask = (offs_cm[:, None] < (group_idx + 1) * M) and (offs_cn[None, :] < N)
            tl.store(c_ptrs, c, mask=c_mask)
            # matmul end
            cur_block = cur_block + total_cores
        last_count = cur_count % total_cores
        group_start = group_end


def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor:
    assert A.dim() == 2
    assert B.dim() == 2
    AK, M = A.shape
    BK, N = B.shape
    assert A.stride(-1) == 1, "Please make sure A is K-major"
    assert B.stride(-1) == 1, "Please make sure B is K-major"
    assert AK == BK, "Please make sure that A and B have the same seqlen"
    num_groups = size_per_group.shape[0]
    C = A.new_empty(num_groups, M, N)
    num_cores = get_number_cores()

    def grid(META):
        assert M % META["BLOCK_M"] == 0, "Only support when M is a multiple of BLOCK_M"
        return (num_cores, )

    k_grouped_gemm_kernel[grid](A, B, C, size_per_group, num_groups, M, N)
    # print(f"best config {k_grouped_gemm_kernel.best_config}", flush=True)
    return C
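To make the scheduling in `k_grouped_gemm_kernel` easier to follow, here is a host-side Python sketch (illustration only, not part of the PR) of how the persistent loop hands out `(group, tile)` tasks: each core claims every `total_cores`-th position of a rolling counter whose phase carries over from one group to the next, so tiles stay evenly spread even when group sizes differ. The diagonal tile remapping done by `grouped_lanuch_diagonal` is elided.

```python
def simulate_k_grouped_schedule(num_groups, num_block_m, num_block_n, total_cores):
    # Mirrors the loop structure of k_grouped_gemm_kernel; tiles are listed row-major
    # here, whereas the kernel remaps them diagonally via grouped_lanuch_diagonal.
    blocks_per_group = num_block_m * num_block_n
    tasks = {core: [] for core in range(total_cores)}
    for core_idx in range(total_cores):
        last_count = 0
        for group_idx in range(num_groups):
            cur_count = last_count + blocks_per_group
            # smallest counter value >= last_count owned by this core (counter % total_cores == core_idx)
            cur_block = core_idx if core_idx >= last_count else core_idx + total_cores
            while cur_block < cur_count:
                m_idx, n_idx = divmod(cur_block - last_count, num_block_n)
                tasks[core_idx].append((group_idx, m_idx, n_idx))
                cur_block += total_cores
            last_count = cur_count % total_cores
    return tasks
```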