Skip to content

Commit cba4060

Browse files
committed
[Examples] apply flash attention op for decode phase
1 parent 26168fa commit cba4060

File tree

3 files changed

+77
-79
lines changed

3 files changed

+77
-79
lines changed

examples/BuddyDeepSeekR1/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ add_custom_command(
264264

265265
add_custom_command(
266266
OUTPUT subgraph-f16.o
267-
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0-f16.mlir
267+
COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0-f16.mlir
268+
-simplify-tosa-reshape |
269+
${LLVM_TOOLS_BINARY_DIR}/mlir-opt
268270
-pass-pipeline ${TOSA_PIPELINE} |
269271
${BUDDY_BINARY_DIR}/buddy-opt
270272
-eliminate-empty-tensors
@@ -277,7 +279,8 @@ add_custom_command(
277279
-ownership-based-buffer-deallocation
278280
-buffer-deallocation-simplification
279281
-bufferization-lower-deallocations
280-
-matmul-parallel-vectorization-optimize
282+
-assume-tight-memref-layout
283+
-matmul-vectorization-blis
281284
-batchmatmul-optimize
282285
-canonicalize
283286
-cse
@@ -292,7 +295,7 @@ add_custom_command(
292295
-canonicalize
293296
-cse
294297
-sccp
295-
-remove-dead-values
298+
# -remove-dead-values
296299
-memref-expand
297300
-arith-expand
298301
-convert-vector-to-llvm

examples/BuddyDeepSeekR1/import-deepseek-r1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
simply_fuse,
3939
apply_classic_fusion,
4040
eliminate_transpose,
41+
flash_attention,
4142
)
4243
from buddy.compiler.graph.type import DeviceType
4344
from buddy.compiler.graph.operation import *
@@ -175,10 +176,11 @@
175176
params = dynamo_compiler_prefill.imported_params[graph_prefill]
176177
graphs_prefill[0].perform([eliminate_transpose])
177178
graphs_decode[0].perform([eliminate_transpose])
178-
pattern_list = [simply_fuse]
179+
pattern_list_prefill = [simply_fuse]
180+
pattern_list_decode = [simply_fuse,flash_attention]
179181

180-
graphs_prefill[0].fuse_ops(pattern_list)
181-
graphs_decode[0].fuse_ops(pattern_list)
182+
graphs_prefill[0].fuse_ops(pattern_list_prefill)
183+
graphs_decode[0].fuse_ops(pattern_list_decode)
182184

183185
graph_prefill.op_groups["subgraph0_prefill"] = graph_prefill.op_groups.pop(
184186
"subgraph0"

0 commit comments

Comments
 (0)