perf: Use 2WG pipeline design for MLA implementation on Hopper #952

Open · wants to merge 16 commits into `main`
Conversation

@yzh119 (Collaborator) commented Mar 17, 2025

This PR implements #892.

Per our benchmarks, the 2WG pipeline (FlashMLA's design) is faster than our current 3WG pipeline on Hopper. While it is still under investigation where the gap comes from, we should implement the 2WG (and, in the future, 4WG) pipeline in FlashInfer to make sure our implementation does not lag behind FlashMLA.
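
For readers unfamiliar with the two designs, below is a minimal structural sketch of the 2WG shape. This is not the kernel in this PR: it assumes the common Hopper layout in which "3WG" means one dedicated TMA-producer warpgroup plus two math warpgroups, while "2WG" means both warpgroups do math and the loads are software-pipelined through a double buffer. All identifiers (`TILE_KV`, `HEAD_DIM`, `mla_2wg_sketch`, ...) are hypothetical, and plain loads stand in for TMA/`cp.async`.

```cuda
#include <cuda_runtime.h>

constexpr int TILE_KV = 16;    // KV rows staged per pipeline step (assumed;
                               // kept small so static smem stays under 48 KB)
constexpr int HEAD_DIM = 128;  // per-warpgroup slice of the head dim (assumed)
constexpr int WG_SIZE = 128;   // one Hopper warpgroup = 4 warps = 128 threads

// Launch with one block of 2 * WG_SIZE = 256 threads per tile of queries.
__global__ void mla_2wg_sketch(const float* __restrict__ kv,  // [num_tiles * TILE_KV, 2 * HEAD_DIM]
                               float* __restrict__ out,       // [2 * WG_SIZE]
                               int num_tiles) {
  // Double-buffered staging area: smem[stage][row][col].
  __shared__ float smem[2][TILE_KV][2 * HEAD_DIM];
  const int wg = threadIdx.x / WG_SIZE;    // warpgroup id: 0 or 1
  const int lane = threadIdx.x % WG_SIZE;  // thread id within the warpgroup

  // Prologue: all 256 threads cooperatively fill stage 0 -- unlike the 3WG
  // design, there is no dedicated producer warpgroup.
  for (int i = threadIdx.x; i < TILE_KV * 2 * HEAD_DIM; i += blockDim.x)
    smem[0][i / (2 * HEAD_DIM)][i % (2 * HEAD_DIM)] = kv[i];
  __syncthreads();

  float acc = 0.f;  // stand-in for the attention accumulator registers
  for (int t = 0; t < num_tiles; ++t) {
    const int cur = t & 1, nxt = cur ^ 1;
    // Issue the copy of tile t+1 into the spare stage while tile t is live.
    if (t + 1 < num_tiles) {
      const float* src = kv + (size_t)(t + 1) * TILE_KV * (2 * HEAD_DIM);
      for (int i = threadIdx.x; i < TILE_KV * 2 * HEAD_DIM; i += blockDim.x)
        smem[nxt][i / (2 * HEAD_DIM)][i % (2 * HEAD_DIM)] = src[i];
    }
    // Both warpgroups compute; each owns one half of the head dimension.
    // A row-sum stands in for the real QK^T / softmax / PV work.
    for (int r = 0; r < TILE_KV; ++r)
      acc += smem[cur][r][wg * HEAD_DIM + lane];
    __syncthreads();  // stage t fully read, stage t+1 fully written: swap
  }
  out[threadIdx.x] = acc;
}
```

With TMA or `cp.async` driving the fills, the copy of the `nxt` stage genuinely overlaps the math on `cur`; the plain loads above only illustrate the buffer choreography.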

Performance

Before this PR:

| batch_size | seq_len | num_heads | Memory bandwidth (GB/s) | FLOPs (TFLOPs) |
|---:|---:|---:|---:|---:|
| 64 | 1024 | 64 | 1547.23 | 167.29 |
| 64 | 1024 | 128 | 1483.82 | 290.23 |
| 128 | 1024 | 64 | 2238.72 | 242.06 |
| 128 | 1024 | 128 | 1612.66 | 315.43 |
| 768 | 1024 | 64 | 2821.32 | 305.05 |
| 768 | 1024 | 128 | 1767.63 | 345.74 |
| 64 | 2048 | 64 | 1960.50 | 223.79 |
| 64 | 2048 | 128 | 1533.88 | 331.70 |
| 128 | 2048 | 64 | 2546.83 | 290.72 |
| 128 | 2048 | 128 | 1629.73 | 352.43 |
| 768 | 2048 | 64 | 2820.22 | 321.93 |
| 768 | 2048 | 128 | 1657.89 | 358.52 |
| 64 | 8192 | 64 | 2682.98 | 319.63 |
| 64 | 8192 | 128 | 1600.79 | 375.94 |
| 128 | 8192 | 64 | 2803.48 | 333.98 |
| 128 | 8192 | 128 | 1584.79 | 372.18 |
| 768 | 8192 | 64 | 2768.36 | 329.80 |
| 768 | 8192 | 128 | 1565.82 | 367.73 |

After this PR:

| batch_size | seq_len | num_heads | Memory bandwidth (GB/s) | FLOPs (TFLOPs) |
|---:|---:|---:|---:|---:|
| 64 | 1024 | 64 | 1520.70 | 164.42 |
| 64 | 1024 | 128 | 1807.33 | 353.51 |
| 128 | 1024 | 64 | 2327.25 | 251.63 |
| 128 | 1024 | 128 | 2024.00 | 395.88 |
| 768 | 1024 | 64 | 2897.75 | 313.32 |
| 768 | 1024 | 128 | 2256.69 | 441.40 |
| 64 | 2048 | 64 | 1963.77 | 224.17 |
| 64 | 2048 | 128 | 2011.81 | 435.05 |
| 128 | 2048 | 64 | 2638.88 | 301.23 |
| 128 | 2048 | 128 | 2168.25 | 468.88 |
| 768 | 2048 | 64 | 3008.55 | 343.43 |
| 768 | 2048 | 128 | 2175.46 | 470.44 |
| 64 | 8192 | 64 | 2724.09 | 324.52 |
| 64 | 8192 | 128 | 2153.42 | 505.72 |
| 128 | 8192 | 64 | 3015.30 | 359.22 |
| 128 | 8192 | 128 | 2120.13 | 497.91 |
| 768 | 8192 | 64 | 3100.86 | 369.41 |
| 768 | 8192 | 128 | 2096.94 | 492.46 |

In short, the num_heads=128 configurations improve by roughly 22% to 35% in both bandwidth and FLOPs, while the num_heads=64 configurations range from roughly unchanged to about 12% faster.

@yzh119 changed the title from "[WIP] Use 2WG pipeline design for MLA implementation on Hopper" to "perf: Use 2WG pipeline design for MLA implementation on Hopper" on Mar 26, 2025
@yzh119 marked this pull request as ready for review on March 26, 2025