
[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend #15655

Status: Open. Wants to merge 186 commits into main from the tpu_bgmv_optimisation branch.
Conversation

@Akshat-Tripathi (Contributor) commented Mar 27, 2025

Summary

This PR optimises the Multi-LoRA implementation from #14238 and should be merged after it.

This includes several kernel optimisations:

  • Block size tuning 2bb8868 d7338f8
  • Faster mask creation 2aacb34
  • Allowing for some blocks to be skipped 6ee0b57
  • Adding LoRA Laning eb804a0
  • Splitting the Pallas kernel into shrink/expand variants de6746a (see the sketch after this list)
  • Removing masking when only 1 LoRA adapter is used aad109b
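
To make the shrink/expand split and the single-adapter fast path concrete, here is a minimal plain-JAX sketch of the semantics the bgmv kernels compute. The shapes, function names, and reference implementation below are illustrative assumptions, not the Pallas code in this PR.

```python
# Minimal plain-JAX sketch (NOT the actual Pallas kernels in this PR) of the
# semantics behind the shrink/expand split and the single-adapter fast path.
# Shapes and names are illustrative assumptions.
import jax.numpy as jnp

def bgmv_shrink(x, lora_a, token_lora_ids):
    # x: [num_tokens, hidden], lora_a: [num_loras, rank, hidden]
    a = lora_a[token_lora_ids]                 # gather each token's LoRA-A
    return jnp.einsum("th,trh->tr", x, a)      # -> [num_tokens, rank]

def bgmv_expand(v, lora_b, token_lora_ids, y):
    # v: [num_tokens, rank], lora_b: [num_loras, out_dim, rank]
    b = lora_b[token_lora_ids]                 # gather each token's LoRA-B
    return y + jnp.einsum("tr,tor->to", v, b)  # accumulate into the base output

def bgmv_expand_single_lora(v, lora_b, y):
    # With exactly one active adapter there is nothing to gather or mask:
    # every token uses lora_b[0], so this collapses to a single dense matmul.
    return y + v @ lora_b[0].T
```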

And a few general ones:

  • Pre-transposing the LoRA adapters used in the expand op a82f3fe (see the sketch after this list)
  • Reducing recompilations 5638e7d
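
As a hedged illustration of the pre-transposition, the LoRA-B stack can be transposed once at adapter load time so the expand op consumes a [num_loras, rank, out_dim] layout directly instead of transposing in the hot path. This reuses the hypothetical shapes from the sketch above and is not the PR's implementation.

```python
# Hypothetical sketch of the pre-transposition, reusing the shapes above.
import jax.numpy as jnp

def pretranspose_lora_b(lora_b):
    # Run once when the adapter is loaded, not in the forward pass.
    # lora_b: [num_loras, out_dim, rank] -> [num_loras, rank, out_dim]
    return jnp.transpose(lora_b, (0, 2, 1))

def bgmv_expand_pretransposed(v, lora_b_t, token_lora_ids, y):
    # v: [num_tokens, rank], lora_b_t: [num_loras, rank, out_dim]
    b_t = lora_b_t[token_lora_ids]             # [num_tokens, rank, out_dim]
    return y + jnp.einsum("tr,tro->to", v, b_t)
```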

Things left/RFC

  • There are still a few recompilations at the start of a run that I need to track down
  • LogitsProcessorWithLoRA introduces a long (~1.5 second) stall when it's enabled, but not much activity seems to happen on the CPU or TPU during this time. I've disabled this for now.
  • It seems LogitsProcessorWithLoRA is always created even when no LoRA adapter needs it; is there a reason for this?
  • I have microbenchmarks for the kernels, but I'm not sure where they should live.

@mgoin self-requested a review April 3, 2025 09:26

mergify bot commented Apr 3, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 3, 2025
@Akshat-Tripathi force-pushed the tpu_bgmv_optimisation branch from eac3b95 to 49157b1 on April 4, 2025 13:02
@Akshat-Tripathi (Contributor, Author) commented:

I've got some performance numbers using the MLPerf Llama2-70B inference benchmark. I retokenised the dataset for Llama3.1.

| Model | Parameters | Without LoRA (tok/s) | With LoRA (tok/s) |
|---|---|---|---|
| Llama3.1 | 8B | 1621.62 | 1426.01 |
| Llama3.1 | 70B | 432.964 | 326.709 |

@psyhtest commented Apr 7, 2025

Benchmarking LoRA against baseline (no LoRA) throughput

We use NVIDIA's GenAI-Perf tool to force fixed-length inputs and outputs and produce "heatmap" plots like those below. On TPU-v6e and H100 instances, we vary the input lengths from 128 to 8k tokens; on L4 instances, from 128 to 2k tokens.

We calculate the LoRA slowdown as ((LoRA throughput / baseline throughput) - 1) * 100%.
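
As a worked example of this metric, using the MLPerf throughput table from the earlier comment (not the GenAI-Perf runs below): for Llama3.1 8B, (1426.01 / 1621.62 - 1) * 100% ≈ -12.1%.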

Llama3.1-8B

1x TPU-v6e

The LoRA slowdown varies from -8.4% to -23.9%.

[Heatmap: Llama3.1-8B, 1x TPU-v6e]

1x GPU-L4

The LoRA slowdown varies from -17.3% to -32.8%.

[Heatmap: Llama3.1-8B, 1x GPU-L4]

1x GPU-H100

The LoRA slowdown varies from -10.0% to -51.8%.

[Heatmap: Llama3.1-8B, 1x GPU-H100]

Llama3.1-70B

8x TPU-v6e

The LoRA slowdown varies from -20.7% to -46.3%.

[Heatmap: Llama3.1-70B, 8x TPU-v6e]

8x GPU-L4

The LoRA slowdown varies from -13.8% (second best: -25.1%) to -49.7%.

[Heatmap: Llama3.1-70B, 8x GPU-L4]

4x GPU-H100

We were unable to launch VMs due to persistent unavailability across multiple zones and regions.

@mergify mergify bot removed the needs-rebase label Apr 9, 2025
Labels: ci/build, tpu (Related to Google TPUs), v1