
Commit c0409aa

bertmaher authored and facebook-github-bot committed
Add FlexAttention (#2443)
Summary:
```
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention

  (Batch, Heads, SeqLen, Dhead)    sdpa-latency    flash_v2-latency    flex_attention-latency
  -------------------------------  --------------  ------------------  ------------------------
  (32, 16, 512, 128)               0.24512         0.236768            0.257984
  (16, 16, 1024, 128)              0.442944        0.41968             0.419008
  (8, 16, 2048, 128)               0.845728        0.798688            0.74368
  (4, 16, 4096, 128)               1.65574         1.55882             1.4041
  (2, 16, 8192, 128)               3.27904         3.08669             2.73846
  (1, 16, 16384, 128)              6.55098         6.14246             5.38931

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --causal

  (Batch, Heads, SeqLen, Dhead)    sdpa-latency    flash_v2-latency    flex_attention-latency
  -------------------------------  --------------  ------------------  ------------------------
  (32, 16, 512, 128)               0.199136        0.187424            0.201632
  (16, 16, 1024, 128)              0.298208        0.278048            0.28912
  (8, 16, 2048, 128)               0.51504         0.481088            0.46528
  (4, 16, 4096, 128)               0.9584          0.890688            0.82928
  (2, 16, 8192, 128)               1.84317         1.70605             1.55763
  (1, 16, 16384, 128)              3.62157         3.34694             3.02374

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd

  (Batch, Heads, SeqLen, Dhead)    sdpa-latency    flash_v2-latency    flex_attention-latency
  -------------------------------  --------------  ------------------  ------------------------
  (32, 16, 512, 128)               1.36323         1.30051             0.94192
  (16, 16, 1024, 128)              1.89187         1.80678             1.40486
  (8, 16, 2048, 128)               2.93325         2.83165             2.33082
  (4, 16, 4096, 128)               5.05456         4.91002             4.19229
  (2, 16, 8192, 128)               9.34131         9.10381             7.87952
  (1, 16, 16384, 128)              17.9824         17.5658             15.4029

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd --causal

  (Batch, Heads, SeqLen, Dhead)    sdpa-latency    flash_v2-latency    flex_attention-latency
  -------------------------------  --------------  ------------------  ------------------------
  (32, 16, 512, 128)               1.14022         1.07926             0.838688
  (16, 16, 1024, 128)              1.41398         1.33376             1.06554
  (8, 16, 2048, 128)               1.97437         1.87725             1.53197
  (4, 16, 4096, 128)               3.08925         2.95606             2.48413
  (2, 16, 8192, 128)               5.28237         5.08138             4.37834
  (1, 16, 16384, 128)              9.27056         8.94672             8.24531
```

Pull Request resolved: #2443

Reviewed By: xuzhao9

Differential Revision: D61924897

Pulled By: bertmaher

fbshipit-source-id: 40e18e6bee2b91e4a9826f5056a431950ee3495d
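For orientation (not part of the commit), the sketch below is a rough standalone way to reproduce one of the forward-pass rows above outside the TritonBench harness. It assumes PyTorch 2.5+ with CUDA and fp16 inputs; the `bench_ms` helper, warmup count, and iteration count are arbitrary illustrative choices, so the numbers will only loosely track the table, which comes from TritonBench's own harness.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention


def bench_ms(fn, iters=100):
    # Crude CUDA-event timer: warm up, then report average milliseconds per call.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


B, H, S, D = 8, 16, 2048, 128  # one of the (Batch, Heads, SeqLen, Dhead) rows above
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Compile flex_attention the same way the benchmark does, and trigger
# compilation once before timing.
compiled_flex = torch.compile(flex_attention, dynamic=False)
compiled_flex(q, k, v)

print("sdpa ms:", bench_ms(lambda: F.scaled_dot_product_attention(q, k, v)))
print("flex ms:", bench_ms(lambda: compiled_flex(q, k, v)))
```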
1 parent 31d07c6 commit c0409aa

File tree

1 file changed: +19 −0 lines

torchbenchmark/operators/flash_attention/operator.py

Lines changed: 19 additions & 0 deletions
@@ -358,6 +358,25 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark()
+    def flex_attention(self, q, k, v):
+        from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        flex_attention = torch.compile(flex_attention, dynamic=False)
+
+        if self.causal:
+            B, H, S, D = q.shape
+            block_mask = create_block_mask(
+                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+            )
+        else:
+            block_mask = None
+
+        return lambda: flex_attention(q, k, v, block_mask=block_mask)
+
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
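As a sanity check on the causal path the new variant exercises, the following sketch (again not part of the diff; it assumes PyTorch 2.5+ with CUDA and fp16 tensors) builds the same block mask standalone and compares the result against SDPA with `is_causal=True`. The tolerances are loose, illustrative choices for fp16.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 4, 16, 1024, 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))


def causal_mask(b, h, q_idx, kv_idx):
    # Each query position may attend only to itself and earlier key positions.
    return q_idx >= kv_idx


# B=None / H=None broadcast the mask over batch and heads, as in the diff above.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
compiled_flex = torch.compile(flex_attention, dynamic=False)

out_flex = compiled_flex(q, k, v, block_mask=block_mask)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_flex, out_sdpa, atol=2e-3, rtol=2e-3)
```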
