Conversation

@sdh1014 (Contributor) commented Nov 9, 2025

1. Overview

  1. Count the sizes and frequencies of all linalg.matmul ops appearing in build/examples/BuddyDeepSeekR1/subgraph0_prefill.mlir and build/examples/BuddyDeepSeekR1/subgraph0_decode.mlir.
  2. Use the existing three matrix vectorization passes to test the performance of linalg.matmul across different dimensions.

2. Experiment Environment

Hardware Configuration

  • CPU: Intel Xeon Platinum 8575C
  • Total Logical CPUs: 192 (2 sockets × 48 physical cores × SMT2)
  • Frequency: max 4.0 GHz

Software Environment

  • Project: examples/BuddyNext/next-linalg-matmul-vec-perf
  • Compilation Command: make next-linalg-matmul-vec-perf-run

3. Statistics of linalg.matmul op sizes

The following table summarizes the sizes and occurrences of linalg.matmul ops found in subgraph0_prefill.mlir and subgraph0_decode.mlir.

| File | M | N | K | Count |
| --- | --- | --- | --- | --- |
| subgraph0_prefill.mlir | 1024 | 256 | 1536 | 56 |
| | 1024 | 1536 | 1536 | 56 |
| | 1024 | 1536 | 8960 | 28 |
| | 1024 | 8960 | 1536 | 56 |
| | 1024 | 151936 | 1536 | 1 |
| subgraph0_decode.mlir | 1 | 256 | 1536 | 56 |
| | 1 | 1536 | 1536 | 56 |
| | 1 | 1536 | 8960 | 28 |
| | 1 | 8960 | 1536 | 56 |
| | 1 | 151936 | 1536 | 1 |
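The counting step described in the overview can be sketched with a small script. This is a minimal sketch, not part of the repository: it assumes the common textual form of `linalg.matmul` where the two `ins` operands carry M×K and K×N shaped `tensor`/`memref` types, and the regex is illustrative.

```python
import re
from collections import Counter

def count_matmul_sizes(mlir_text):
    """Count (M, N, K) triples of linalg.matmul ops in MLIR text.

    Assumes the textual form where the first ins operand is MxK and
    the second is KxN (tensor or memref element types are ignored).
    """
    # Matches e.g.:
    # linalg.matmul ins(%a, %b : tensor<1024x1536xf32>, tensor<1536x256xf32>) ...
    pattern = re.compile(
        r"linalg\.matmul\s+ins\([^:]*:\s*"
        r"(?:tensor|memref)<(\d+)x(\d+)x[^>]*>,\s*"
        r"(?:tensor|memref)<(\d+)x(\d+)x[^>]*>"
    )
    counts = Counter()
    for m_dim, k1, k2, n_dim in pattern.findall(mlir_text):
        assert k1 == k2, "inner dimensions of the two operands must match"
        counts[(int(m_dim), int(n_dim), int(k1))] += 1
    return counts
```

Running this over `subgraph0_prefill.mlir` and `subgraph0_decode.mlir` would produce the frequency table above.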

4. Comparison of Different Passes

In examples/BuddyNext/next-linalg-matmul-perf.mlir, the 10 different linalg.matmul op sizes were tested with three optimization passes:

  • -matmul-parallel-vectorization-optimize
  • -matmul-vectorization
  • -matmul-vectorization-blis

The execution times were recorded for each case.

Multi-thread Comparison

The table below shows execution times (in seconds) of each linalg.matmul op under different matrix vectorization passes with multi-threading enabled.

| M | N | K | -matmul-parallel-vectorization-optimize | -matmul-vectorization | -matmul-vectorization-blis | -matmul-vectorization-decode |
| --- | --- | --- | --- | --- | --- | --- |
| 1024 | 256 | 1536 | 0.237258 | 0.00485735 | 0.021553 | N/A |
| 1024 | 1536 | 1536 | 0.268647 | 0.0228972 | 0.0349704 | N/A |
| 1024 | 1536 | 8960 | 4.12543 | 0.0456994 | 0.135663 | N/A |
| 1024 | 8960 | 1536 | 0.549615 | 0.0125531 | 0.0147408 | N/A |
| 1024 | 151936 | 1536 | 0.563397 | 0.451447 | 0.0639093 | N/A |
| 1 | 256 | 1536 | 0.000183105 | 0.000132418 | 0.000973797 | 0.000794983 |
| 1 | 1536 | 1536 | 0.000267029 | 0.000900793 | 0.000775242 | 0.0119066 |
| 1 | 1536 | 8960 | 0.00134206 | 0.00518198 | 0.00440283 | 0.0675052 |
| 1 | 8960 | 1536 | 0.000965643 | 0.00532303 | 0.00139279 | 0.0243661 |
| 1 | 151936 | 1536 | 0.0157417 | 0.0994385 | 0.0175184 | 0.343831 |

Single-thread Comparison

In the multi-thread version, the passes include parallelization optimizations. To isolate the effect of each matmul vectorization pass, the parallelization-related passes were removed for a single-thread comparison.

| M | N | K | -matmul-parallel-vectorization-optimize | -matmul-vectorization | -matmul-vectorization-blis |
| --- | --- | --- | --- | --- | --- |
| 1024 | 256 | 1536 | 0.117767 | 0.0109147 | 0.00782201 |
| 1024 | 1536 | 1536 | 0.744654 | 0.126912 | 0.0525613 |
| 1024 | 1536 | 8960 | 2.58169 | 0.580122 | 0.329709 |
| 1024 | 8960 | 1536 | 2.98968 | 0.584138 | 0.303229 |
| 1024 | 151936 | 1536 | 29.7125 | 28.2054 | 4.6284 |
| 1 | 256 | 1536 | 0.0000550747 | 0.0000288785 | 0.000478864 |
| 1 | 1536 | 1536 | 0.000435829 | 0.000406116 | 0.00323114 |
| 1 | 1536 | 8960 | 0.00604701 | 0.00360799 | 0.0191221 |
| 1 | 8960 | 1536 | 0.00317121 | 0.0023239 | 0.0190809 |
| 1 | 151936 | 1536 | 0.08087611 | 0.0837886 | 0.339117 |
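Using the numbers reported for the largest size, [1024, 151936, 1536], the multi-thread speedup of each pass over its single-thread run can be computed directly (the short dictionary keys are shorthand for the pass names, not real flags):

```python
# Single- vs multi-thread times (seconds) for M=1024, N=151936, K=1536,
# taken from the two comparison tables.
single = {"parallel-vec-opt": 29.7125, "vectorization": 28.2054, "blis": 4.6284}
multi = {"parallel-vec-opt": 0.563397, "vectorization": 0.451447, "blis": 0.0639093}

# Speedup = single-thread time / multi-thread time, per pass.
speedup = {p: single[p] / multi[p] for p in single}
# e.g. blis: 4.6284 / 0.0639093, roughly 72x
```

This illustrates why the final benchmark keeps the multi-threaded configuration: at this size every pass gains well over an order of magnitude from threading.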

The final version examples/BuddyNext/next-matmul-linalg-bench.mlir uses the multi-threaded parallel version.

Test Screenshots

Due to the large number of screenshots, only the -matmul-parallel-vectorization-optimize case at size [1024, 1536, 8960], where an unusually long execution time was observed, is shown below.
Although all runs used the same pass, the timings vary significantly; the upper row shows the expected result.

(screenshot: timing output for the [1024, 1536, 8960] case)

5. Conclusion

  1. The following table lists the best-performing pass for each linalg.matmul op size under multi-threaded parallel conditions.

| M | N | K | Best-performing pass |
| --- | --- | --- | --- |
| 1024 | 256 | 1536 | matmul-vectorization |
| 1024 | 1536 | 1536 | matmul-vectorization ≈ matmul-vectorization-blis |
| 1024 | 1536 | 8960 | matmul-vectorization |
| 1024 | 8960 | 1536 | matmul-vectorization ≈ matmul-vectorization-blis |
| 1024 | 151936 | 1536 | matmul-vectorization-blis |
| 1 | 256 | 1536 | matmul-parallel-vectorization-optimize ≈ matmul-vectorization |
| 1 | 1536 | 1536 | matmul-parallel-vectorization-optimize |
| 1 | 1536 | 8960 | matmul-parallel-vectorization-optimize |
| 1 | 8960 | 1536 | matmul-parallel-vectorization-optimize |
| 1 | 151936 | 1536 | matmul-parallel-vectorization-optimize |

  2. The matmul-parallel-vectorization-optimize pass shows abnormally long execution time at size [1024, 1536, 8960].
  3. Applying parallel passes with multi-threading significantly reduces runtime, especially in multi-token scenarios.
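The best-pass column above can be derived mechanically from the multi-thread timings; a minimal sketch over two of the rows (the short dictionary keys are shorthand for the pass names, and "N/A" entries are simply omitted):

```python
# Multi-thread timings (seconds) for two of the benchmarked sizes,
# keyed by (M, N, K); values map a shorthand pass name to its time.
timings = {
    (1024, 256, 1536): {"parallel": 0.237258, "vec": 0.00485735, "blis": 0.021553},
    (1, 1536, 8960): {"parallel": 0.00134206, "vec": 0.00518198,
                      "blis": 0.00440283, "decode": 0.0675052},
}

# For each size, pick the pass with the smallest execution time.
best = {size: min(passes, key=passes.get) for size, passes in timings.items()}
```

Extending `timings` with the remaining rows reproduces the full table.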

@zhanghb97 (Member) commented:

@sdh1014 Good comparison

  • Please resolve code conflicts
  • When comparing multi-threading, call the kernel function multiple times in the main function (the first execution has thread initialization overhead)
  • Now a special pass has been added for the decode stage, which can be included in the comparison at decode size: -matmul-vectorization-decode
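The suggested measurement scheme (warm up first, then average repeated calls) can be sketched in Python. This is only an analogue of what the benchmark's main function would do with the MLIR kernel, not the actual harness; the function name and parameters are illustrative.

```python
import time

def bench(kernel, warmup=3, iters=10):
    """Time `kernel` with a warmup phase so one-time costs (such as
    thread-pool initialization on the first call) are excluded, then
    report the mean over several iterations to smooth run-to-run noise."""
    for _ in range(warmup):
        kernel()  # discarded warmup runs
    start = time.perf_counter()
    for _ in range(iters):
        kernel()
    return (time.perf_counter() - start) / iters
```

Without the warmup loop, the first timed call would include thread initialization overhead and skew the multi-thread comparison.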

@sdh1014 force-pushed the feat/buddynext-linalg-matmul-vec-perf branch from 3b32a9e to abbd002 on November 17, 2025, 12:32
@sdh1014 (Contributor, Author) commented Nov 17, 2025

ok, I have now made the following changes:

  • When collecting timing metrics for the linalg.matmul ops, I added a warmup phase and averaged over multiple iterations to reduce the randomness of a single measurement.
  • Added -matmul-vectorization-decode, and updated the multi-threaded comparison of vectorization-pass times at the prefill and decode stages for the different matrix sizes.
| M | N | K | -matmul-parallel-vectorization-optimize | -matmul-vectorization | -matmul-vectorization-blis | -matmul-vectorization-decode |
| --- | --- | --- | --- | --- | --- | --- |
| 1024 | 256 | 1536 | 0.237258 | 0.00485735 | 0.021553 | N/A |
| 1024 | 1536 | 1536 | 0.268647 | 0.0228972 | 0.0349704 | N/A |
| 1024 | 1536 | 8960 | 4.12543 | 0.0456994 | 0.135663 | N/A |
| 1024 | 8960 | 1536 | 0.549615 | 0.0125531 | 0.0147408 | N/A |
| 1024 | 151936 | 1536 | 0.563397 | 0.451447 | 0.0639093 | N/A |
| 1 | 256 | 1536 | 0.000183105 | 0.000132418 | 0.000973797 | 0.000794983 |
| 1 | 1536 | 1536 | 0.000267029 | 0.000900793 | 0.000775242 | 0.0119066 |
| 1 | 1536 | 8960 | 0.00134206 | 0.00518198 | 0.00440283 | 0.0675052 |
| 1 | 8960 | 1536 | 0.000965643 | 0.00532303 | 0.00139279 | 0.0243661 |
| 1 | 151936 | 1536 | 0.0157417 | 0.0994385 | 0.0175184 | 0.343831 |

However, the measurements do not show -matmul-vectorization-decode performing better at the decode sizes.

  • Also updated the relevant content in the PR comment.
