Add TK Wave kernels to attention benchmark for attention-v transpose #44

Open · wants to merge 2 commits into main
Add TK Wave kernels to attention benchmark
Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu committed Jan 29, 2025
commit 0a9526afb30f639bf0c4b79f0a63a3565036a387
7 changes: 7 additions & 0 deletions .github/workflows/run_bench.yml
@@ -47,6 +47,11 @@ jobs:
source bench_venv/bin/activate
python attentionbench/attention_bench.py
- name: TK Attention
run: |
source bench_venv/bin/activate
python attentionbench/attention_bench.py --tk
- name: TK GEMM
run: |
source bench_venv/bin/activate
@@ -66,6 +71,8 @@ jobs:
python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_f16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png
20 changes: 16 additions & 4 deletions attentionbench/attention_bench.py
@@ -10,12 +10,17 @@
from utils import *
from attention_utils import *
from problems import get_attention_configs
from wave_attention_utils import compile_wave_attention_config


def compile_attention(tag, config, kernel_dir, vmfb_dir):
def compile_attention_iree(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_attention_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)

def compile_attention_wave(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_wave_attention_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Config file updater.")
@@ -36,6 +41,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
parser.add_argument("--model", help="roofline on certain model", default=None)
    parser.add_argument('--tk', help="Run attention kernels using Wave Kernels", action=argparse.BooleanOptionalAction)

args = parser.parse_args()
logging.basicConfig(level=args.log_level)
@@ -63,6 +69,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
)
compile_attention = compile_attention_wave if args.tk else compile_attention_iree
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_attention, list(compile_args))))

@@ -80,7 +87,8 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):

results = []
index = 0
output_csv = "results/iree_attention.csv"
output_csv = "results/iree_attention_tk.csv" if args.tk else "results/iree_attention.csv"
entrypoint = "isolated_benchmark" if args.tk else "main"
csv_dir = os.path.dirname(output_csv)
if not os.path.exists(csv_dir):
os.makedirs(csv_dir)
@@ -98,13 +106,17 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--function=main",
f"--function={entrypoint}",
"--benchmark_repetitions=3",
f"--input={query_shape}",
f"--input={key_shape}",
f"--input={value_shape}",
"--benchmark_repetitions=3",
]
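        # Wave kernels take the output buffer as an explicit kernel argument
        # (the `c` parameter of base_attention), so its shape is passed as an extra input.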

if args.tk:
out_shape = config.get_output_shape()
exec_args.append(f"--input={out_shape}")

# iree benchmark kernels
ret_value, cmd_out, cmd_err = run_iree_command(exec_args)
ok = ret_value == 0
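Note on the --tk path: the assembled iree-benchmark-module command differs from the IREE path only in the entrypoint name and the extra output-shape input. A minimal sketch of the resulting exec_args, assuming a hypothetical B=1, M=1024, N=64, K1=64, K2=1024 f16 config on device hip://0 (module path abbreviated):

exec_args = [
    "iree-benchmark-module",
    "--device=hip://0",
    "--device_allocator=caching",
    "--module=<config>.vmfb",
    "--function=isolated_benchmark",  # "main" without --tk
    "--benchmark_repetitions=3",
    "--input=1x1024x64xf16",   # query  (B x M x K1)
    "--input=1x1024x64xf16",   # key    (B x K2 x K1)
    "--input=1x64x1024xf16",   # value  (B x N x K2)
    "--input=1x1024x64xf32",   # output buffer (B x M x N in f32), --tk only
]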
16 changes: 13 additions & 3 deletions attentionbench/attention_utils.py
@@ -63,6 +63,16 @@ def get_pv_intrinsic(intrinsic: IntrinsicType):
case _:
return intrinsic

def get_32_bit_type(input_type: str):
assert isinstance(input_type, str)
match input_type[0]:
case "f":
return "f32"
case "i":
return "i32"
case _:
raise NotImplementedError("Unexpected type to obtain 32 bit type on attention utils.")

@dataclass
class AttentionConfig:
B: int
@@ -82,10 +92,10 @@ def get_key_shape(self) -> str:
return f"{self.B}x{self.K2}x{self.K1}x{self.dtype}"

def get_value_shape(self) -> str:
return f"{self.B}x{self.K2}x{self.N}x{self.dtype}"
return f"{self.B}x{self.N}x{self.K2}x{self.dtype}"

def get_output_shape(self) -> str:
return f"{self.B}x{self.M}x{self.N}x{self.dtype}"
return f"{self.B}x{self.M}x{self.N}x{get_32_bit_type(self.dtype)}"

def get_byte_count(self) -> int:
dtype_bits_map = {
@@ -198,7 +208,7 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
attn_kernel = f"""
#Q = affine_map<(b, m, n, k1, k2) -> (b, m, k1)>
#K = affine_map<(b, m, n, k1, k2) -> (b, k2, k1)>
#V = affine_map<(b, m, n, k1, k2) -> (b, k2, n)>
#V = affine_map<(b, m, n, k1, k2) -> (b, n, k2)>
#S = affine_map<(b, m, n, k1, k2) -> ()>
#O = affine_map<(b, m, n, k1, k2) -> (b, m, n)>
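These attention_utils.py changes carry the attention-v transpose named in the PR title: the value tensor is now laid out as B x N x K2 (mirrored by the #V affine map), and the output shape uses the 32-bit accumulator type from get_32_bit_type. A minimal sketch of the resulting shape strings, assuming a hypothetical config with B=1, M=1024, N=64, K1=64, K2=1024 and dtype f16:

# Shape strings as built by AttentionConfig after this change.
B, M, N, K1, K2, dtype = 1, 1024, 64, 64, 1024, "f16"
query_shape  = f"{B}x{M}x{K1}x{dtype}"    # 1x1024x64xf16
key_shape    = f"{B}x{K2}x{K1}x{dtype}"   # 1x1024x64xf16
value_shape  = f"{B}x{N}x{K2}x{dtype}"    # 1x64x1024xf16  (previously 1x1024x64xf16)
output_shape = f"{B}x{M}x{N}xf32"         # 1x1024x64xf32  (f16 and f8 both map to f32)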
244 changes: 244 additions & 0 deletions attentionbench/wave_attention_utils.py
@@ -0,0 +1,244 @@
from utils import *
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from attention_utils import AttentionConfig
import traceback

try:
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils import (
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
)
except ImportError:
TURBINE_AVAILABLE = False
else:
TURBINE_AVAILABLE = True

@dataclass
class AttentionShape:
num_query_heads: int
num_kv_heads: int
head_size: int
head_size_kv: int
# -----------------------
# Prefill specific
num_seqs: Optional[int] = None
max_seq_len: Optional[int] = None
total_seq_len: Optional[int] = None
# -----------------------
# Vanilla attention
query_seq_len: Optional[int] = None
kv_seq_len: Optional[int] = None

def get_vanilla_attention_kernel(
shape: AttentionShape, mfma_variant: MMAType, dynamic_dims: bool, input_dtype: "dtype"
):
# Input sizes
B = tkl.sym.B
M = tkl.sym.M
N = tkl.sym.N
K1 = tkl.sym.K1
K2 = tkl.sym.K2
# Workgroup tile sizes
BLOCK_B = tkl.sym.BLOCK_B
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K2 = tkl.sym.BLOCK_K2
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
# Other hyperparameters
LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK")
LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV")
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)]
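    # One workgroup computes a BLOCK_M x BLOCK_N output tile per batch entry,
    # iterating over K2 in BLOCK_K2 steps; BLOCK_M is further split across the
    # four waves of the workgroup.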

if mfma_variant[1] == MMAType.F32_16x16x16_F16 or mfma_variant[1] == MMAType.F32_16x16x32_F8:
Mvec = 16
Nvec = 16
if mfma_variant[1] == MMAType.F32_32x32x8_F16 or mfma_variant[1] == MMAType.F32_32x32x16_F8:
Mvec = 32
Nvec = 32

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(4, 1, 1),
mma_type=mfma_variant[1],
vector_shapes={B: 0, M: Mvec, N: Nvec},
)
]

if dynamic_dims:
constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.iterator(2)
mapping = tkw.IndexMapping(
num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
)
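    # The accumulator is kept as (B, N, M) in registers; this mapping transposes
    # it back to the (B, M, N) layout of the output when writing `c`.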

@tkw.wave(constraints)
def base_attention(
q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, input_dtype],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, input_dtype],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, input_dtype],
c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
init_sum = tkl.Register[B, M, tkl.f32](0.0)
init_max = tkl.Register[B, M, tkl.f32](-1e6)

# This microkernel encodes the fact that if the reduction
# dimension were tiled, then we would need to materialize a loop.
@tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
def repeat(
partial_max: tkl.Register[B, M, tkl.f32],
partial_sum: tkl.Register[B, M, tkl.f32],
acc: tkl.Register[B, N, M, tkl.f32],
):
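            # One K2 tile of flash attention: S = Q @ K^T for this tile, an
            # online (base-2) softmax update of the running max/sum, and a
            # rescale of the accumulator before adding P @ V.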
imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
m_j = tkw.max(x_j, partial_max, dim=K2)
e_delta_max = tkw.exp2(partial_max - m_j)
e_delta = tkw.exp2(x_j - m_j)
e_init = partial_sum * e_delta_max
d_j = tkw.sum(e_delta, e_init, dim=K2)
imm_f16 = tkw.cast(e_delta, input_dtype)
v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV)
new_acc = acc * e_delta_max
acc = tkw.mma(v_reg, imm_f16, new_acc)
return m_j, d_j, acc

# repeat represents the results of the loop
res_max, res_sum, res_mm = repeat
reciprocal_sum = tkw.reciprocal(res_sum)
res = res_mm * reciprocal_sum
tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD)

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]),
LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]),
STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]),
BLOCK_B: 1,
BLOCK_M: 128,
BLOCK_N: 64,
BLOCK_K2: 64,
B: shape.num_query_heads,
M: shape.query_seq_len,
N: shape.head_size_kv,
K1: shape.head_size,
K2: shape.kv_seq_len,
}
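    # The benchmark path calls this with dynamic_dims=False, so all dims stay
    # static; the branch below is only exercised when experimenting with dynamic shapes.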

dynamic_symbols = []
dynamic_symbols_map = {}
if dynamic_dims:
dynamic_symbols_map[M] = hyperparams[M]
dynamic_symbols_map[N] = hyperparams[N]
dynamic_symbols_map[B] = hyperparams[B]
dynamic_symbols_map[K2] = hyperparams[K2]
dynamic_symbols.append(M)
dynamic_symbols.append(N)
dynamic_symbols.append(B)
dynamic_symbols.append(K2)
del hyperparams[M]
del hyperparams[N]
del hyperparams[B]
del hyperparams[K2]

return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map


def compile_wave_attention_config(
config: AttentionConfig, kernel_dir: Path, vmfb_dir: Path
) -> tuple[Path, Optional[Path]]:
if not TURBINE_AVAILABLE:
raise ValueError("iree.turbine package is not available")

mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")

try:
_compile_attention(config, mlir_file, vmfb_file)
except Exception as e:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(str(e))
f.write(traceback.format_exc())
        return mlir_file, None

return mlir_file, vmfb_file


def _convert_dtype(dtype: str):
dtypes = {
"i8": tkl.i8,
"i16": tkl.i16,
"i32": tkl.i32,
"i64": tkl.i64,
"f8E4M3FNUZ": tkl.f8e4m3fnuz,
"f16": tkl.f16,
"f32": tkl.f32,
"f64": tkl.f64,
"bf16": tkl.bf16,
}
return dtypes[dtype]


def _compile_attention(config: AttentionConfig, mlir_file: Path, vmfb_file: Path):
shape = AttentionShape(
num_query_heads=config.B,
num_kv_heads=config.B,
query_seq_len=config.M,
head_size_kv=config.N,
head_size=config.K1,
kv_seq_len=config.K2,
)

input_dtype = _convert_dtype(config.dtype)
if input_dtype == tkl.f16:
mfma_variant = (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16)
elif input_dtype == tkl.f8e4m3fnuz:
mfma_variant = (MMAType.F32_32x32x16_F8, MMAType.F32_32x32x16_F8)
else:
raise NotImplementedError(f"Got {config.dtype}, TK attention currently only support f8E4M3FNUZ and f16.")

base_attention, hyperparams, _, _ = get_vanilla_attention_kernel(
shape, mfma_variant, False, input_dtype
)

# config = get_default_run_config()
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
create_vmfb_file=vmfb_file,
run_config=config,
schedule=False,
inline=False,
):
mod = base_attention().module_op # This will generate vmfb file
with open(mlir_file, "w") as f:
f.write(str(mod))

print(f"Successfully compiled to {vmfb_file}")