Skip to content

Commit 72c2ae3

Browse files
authored
Update attention template (#30)
Ths commit updates the attention template to include promote operands and decomposition_config. --------- Signed-off-by: Manupa Karunaratne <[email protected]>
1 parent 4621947 commit 72c2ae3

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

.github/workflows/run_bench.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ concurrency:
1616

1717
jobs:
1818
benchmark:
19-
runs-on: mi300-kernel
19+
runs-on: mi300-sdxl-kernel
2020

2121
steps:
2222
- name: "Checkout Repo"

attentionbench/attention_utils.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def get_lowering_config(self) -> str:
7070
f"#iree_gpu.lowering_config<"
7171
+ "{ "
7272
+ f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
73-
+ f"reduction = [{', '.join(map(str, self.reduction_tiles))}]"
73+
+ f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
74+
+ f"promote_operands = [0, 1, 2]"
7475
+ " }"
7576
+ f">"
7677
)
@@ -93,7 +94,7 @@ def get_translation_info(self) -> str:
9394
return (
9495
f"#iree_codegen.translation_info<"
9596
+ f"LLVMGPUVectorDistribute"
96-
+ f" workgroup_size = [{self.N_warp * 64}, {self.M_warp}]"
97+
+ f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
9798
+ f" subgroup_size = 64"
9899
+ f" ,{{mma_schedule = {self.get_mma_schedule()}"
99100
+ f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
@@ -137,6 +138,10 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
137138
%empty = tensor.empty() : !O
138139
%O = iree_linalg_ext.attention
139140
{{ indexing_maps = [#Q, #K, #V, #S, #O]
141+
,decomposition_config = {{
142+
qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
143+
pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
144+
}}
140145
{",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
141146
}}
142147
ins(%Q, %K, %V, %scale : !Q, !K, !V, !dtype) outs(%empty : !O) {{

0 commit comments

Comments
 (0)