[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend #15655
Open: Akshat-Tripathi wants to merge 202 commits into vllm-project:main from krai:tpu_bgmv_optimisation
+1,397 −53
Commits (202)
d993de9
Added non-triton SGMV and BGMV ops (not kernels yet)
Akshat-Tripathi 4f816ed
Made a copy of the layer tests for the TPU. TODO: DRY it out
Akshat-Tripathi 5f0355b
Removed extra print
Akshat-Tripathi edd02c5
Made some minor shape-based fixes to the kernels
Akshat-Tripathi aff94f9
Added basic lora execution code
Akshat-Tripathi adfd194
Replaced einsums with matmuls+reshaping for better xla compilation
Akshat-Tripathi 816a56c
Replaced inf/-inf with max/min since XLA doesn't allow `nan_to_num_()…
Akshat-Tripathi c8a51c8
Added lora config to `_dummy_run()`
Akshat-Tripathi 51f929d
Changed torch._dynamo config
Akshat-Tripathi 23d4a24
Quick patch to allow non lora code to run
Akshat-Tripathi 47397a7
Minor fixes
Akshat-Tripathi 456eb37
Replaced einsums with matmuls to allow xla compilation
Akshat-Tripathi eabc748
Removed xla ops for torch ops
Akshat-Tripathi ac9753e
Removed old debug log points
Akshat-Tripathi aa8b0fd
Fixed bgmv/sgmv shape error
Akshat-Tripathi 124215f
Fixed lora batching crash in warmup
Akshat-Tripathi e148254
Fixed shape issue in add_lora_linear()
Akshat-Tripathi 494b35e
Fixed dynamic lora tensor shapes
Akshat-Tripathi 1dbfcd9
Fixed lora_input preparation for actual execution
Akshat-Tripathi 1bb2578
Fixed wrong model bug
Akshat-Tripathi ddc4cbc
Moved if statements outside of for loops in PunicaWrapperTPU
Akshat-Tripathi 48a6944
Added early exits to PunicaWrapperTPU lora functions
Akshat-Tripathi 7802e84
Added torch ops for tpu (Static prefill sizes)
Akshat-Tripathi ab5396b
XLA bgmv operations are now imported from the default torch_ops
Akshat-Tripathi fdf29d3
Removed TODOs
Akshat-Tripathi c2b4139
Removed old code
Akshat-Tripathi f31b7d1
Linting
Akshat-Tripathi 87ff73e
Fixed import error
Akshat-Tripathi 96c3dde
lint
Akshat-Tripathi 4e72ede
Abstracted out infinity values
Akshat-Tripathi e4d35ce
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi 3cf0680
Removed total_size for linting
Akshat-Tripathi a8ab0c9
Reverted changes to torch_ops
Akshat-Tripathi d73f1ce
Lint
Akshat-Tripathi e01d9a4
Replaced in-place buffer updates with direct returning
Akshat-Tripathi 0c1bfb9
PunicaWrapperTPU now returns unchanged buffer if no loras are needed
Akshat-Tripathi 46ce7fa
Simplified TPU prefill
Akshat-Tripathi 5d0cc37
Removed sgmv kernels from TPU implementation
Akshat-Tripathi 7590b0e
Fix bug
Akshat-Tripathi e7f75b5
Added torch.compiles to PunicaWrapperTPU functions
Akshat-Tripathi fe193f7
Replaced "x[x==-1] = y" with "x = torch.where(x == - 1, y)"
Akshat-Tripathi 52e3911
Revert "Added torch.compiles to PunicaWrapperTPU functions"
Akshat-Tripathi 33a70b0
Fix linting
Akshat-Tripathi 67446b2
Added lora hotswapping test
Akshat-Tripathi 0db19b1
Fixed hotswapping test prompt
Akshat-Tripathi a4c3b0a
Fixed bug in tpu lora test
Akshat-Tripathi 9d6c388
Merged set_no_lora() functionality with _udpate_prefill_metada
Akshat-Tripathi 2a9978e
Added Multi-LoRA functionality to TPU V1
Akshat-Tripathi b8c65bc
Added test that verifies switching
Akshat-Tripathi 942ef07
Added bgmv kernel test code
Akshat-Tripathi 56529b9
Added some dynamic lora selection
Akshat-Tripathi 735073f
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi 1067b50
Added bgmv kernel test
Akshat-Tripathi d897f87
Made bgmv kernel fully functional (WIP on supporting smaller ranks) (…
Akshat-Tripathi d6eca29
Updated bgmv_kernel to work with ranks that aren't exact multiples of…
Akshat-Tripathi d97aae5
Removed interpreted mode on kernel
Akshat-Tripathi 3ac0f63
Added pallas kernel benchmarking script
Akshat-Tripathi a620e58
Fixed mosaic kernel compilation issue
Akshat-Tripathi 00d6dfd
Added reference kernel benchmarking
Akshat-Tripathi fb0601d
Registered the custom op
Akshat-Tripathi 89b062e
Integrated bgmv kernel
Akshat-Tripathi ef2ef8c
Fixed model compilation bugs
Akshat-Tripathi a79e19d
Minor changes
Akshat-Tripathi cc8cdf6
Removed scratch files
Akshat-Tripathi ad8c565
Minor pallas kernel fixes
Akshat-Tripathi 8d83065
integrate ragged paged attn v2
yaochengji dea7d02
fix precompile
yaochengji 0cf0eaa
Merge branch 'chengji/ragged_attn_v2_new' into multi_lora_tpu_v1
Akshat-Tripathi 6249307
Fixed padding issue with v1
Akshat-Tripathi af0a6a9
Added temporary patch over pallas kernel routing bug
Akshat-Tripathi 264d36a
Updated kernel test
Akshat-Tripathi b725c6a
Lint
Akshat-Tripathi 038465c
Removed duplicate method
Akshat-Tripathi 2004369
Lint
Akshat-Tripathi 71a1cdd
More linting
Akshat-Tripathi 3dba9e0
Linting
Akshat-Tripathi f7f95e4
Lint
Akshat-Tripathi adfdcdb
Fixed bug related to consecutive pallas kernels
Akshat-Tripathi a6d5c01
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi 5a27785
Removed v0 TPU LoRA implementation
Akshat-Tripathi 5d15fbc
Fixed VocabParallelEmbeddingWithLoRA compilation error
Akshat-Tripathi ca3d810
Fixed LogitsProcessorWithLoRA layer compilation issue
Akshat-Tripathi 12f71ce
Slightly sped up the kernel
Akshat-Tripathi d040ee8
Lint
Akshat-Tripathi e696144
Fixed bug with higher batch sizes
Akshat-Tripathi d110613
Lint
Akshat-Tripathi f8d5da2
Removed TODO in bgmv pallas test
Akshat-Tripathi d114377
Fixed PunicaWrapperBase typing
Akshat-Tripathi 430bae9
Fixed bug where vLLM crashes on decode
Akshat-Tripathi fb36fd6
Fixed NaN bug with LogitsProcessor
Akshat-Tripathi c454062
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi 23b14d1
Updated LoRALogitsProcessor to work with the TPU
Akshat-Tripathi 27d6f70
Lint
Akshat-Tripathi b547271
Fixed batched logits processing
Akshat-Tripathi f5138b8
Updated kernel test
Akshat-Tripathi ad14872
Added kernel benchmark (dev only, remove later)
Akshat-Tripathi 7418b5a
Tuned bgmv kernel block sizes
Akshat-Tripathi 2aacb34
Improved lora output masking
Akshat-Tripathi 6ee0b57
Skipped matmuls where no loras are needed
Akshat-Tripathi d9e415f
Renamed variables for better readabiity
Akshat-Tripathi 460e808
Moved inner loop into grid spec
Akshat-Tripathi 12ac3b8
Revert "Moved inner loop into grid spec"
Akshat-Tripathi 1bb152f
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi af15bd1
Added comment
Akshat-Tripathi 41555d1
Lint
Akshat-Tripathi 4ac7aa9
Added a fused shrink/expand kernel
Akshat-Tripathi 9f5a497
Revert "Added a fused shrink/expand kernel"
Akshat-Tripathi 54344b7
Added some autotuning for kernels
Akshat-Tripathi c5a42e2
Renamed padding variables
Akshat-Tripathi e66067c
Used a static ones vector, gives a 5%ish perf boost
Akshat-Tripathi 7c79683
Restricted block sizes to prevent memory from blowing up
Akshat-Tripathi d7338f8
Removed larger lora/dim block sizes since they reduce perf outside of…
Akshat-Tripathi 2bb8868
Allowed smaller LoRA blocks if necessary
Akshat-Tripathi 27ad793
Replaced torch.cat operations with F.pad
Akshat-Tripathi a82f3fe
Added fused lora transpose [experimental]
Akshat-Tripathi 640420b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi de6746a
Separated bgmv_shrink and bgmv_expand kernels to avoid unneccessary d…
Akshat-Tripathi 19b9089
Removed redundant branch
Akshat-Tripathi a02d0e9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi e07d6fb
Moved punica related `mark_dynamic` to the TPUModelRunner to allow th…
Akshat-Tripathi 5b4ba1b
Moved `maybe_dummy_run_with_lora` to the `_dummy_run` method
Akshat-Tripathi efbdc62
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi b64dc31
Lint
Akshat-Tripathi 49a8102
Minor fixes + lint
Akshat-Tripathi c1be5f9
Lint
Akshat-Tripathi bf44d65
Fixed mark_dynamic placement for eager/compiled modes
Akshat-Tripathi 15ff074
Fixed mark_dynamic placement for eager/compiled modes
Akshat-Tripathi d9f89b6
Temporary fix to LogitsProcessorWithLoRA pipeline bubble issue
Akshat-Tripathi 81775d3
Sampler is now compiled with LoRA
Akshat-Tripathi ab036e0
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi 2bc00b8
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi 829028d
Removed early exits since they cause eager execution
Akshat-Tripathi b6af323
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi df69c52
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi 5638e7d
Removed some recompilations when updating LoRA metadata
Akshat-Tripathi bae61a2
Aligned lora codepath with recompilation fixes
Akshat-Tripathi dc8b940
Disabled add_lora_logits temporarily
Akshat-Tripathi eb804a0
Added the LoRA Laning optimisation + tests + explanation
Akshat-Tripathi fbb902a
Updated kernel benchmarking script with lora laning
Akshat-Tripathi 8ba2749
Added error for when someone tries to use LoRA adapters on the V0 TPU…
Akshat-Tripathi 51d87a5
Added test to buildkite
Akshat-Tripathi bf52dbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi fce044a
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 8b1dae8
Lint
Akshat-Tripathi aad109b
Optimised single lora kernels
Akshat-Tripathi b09d595
Fixed compilation bug
Akshat-Tripathi 151fde4
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi a1df8c8
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 72d95c6
Fixed LoRA Laning bug
Akshat-Tripathi be0915c
Fixed extra recompilations
Akshat-Tripathi 478a8bb
Lint
Akshat-Tripathi 4178e58
Lint
Akshat-Tripathi 8a3009d
Added type annotation to lora_output
Akshat-Tripathi 493e73f
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 1d6085a
Removed unused function/parameter
Akshat-Tripathi f208234
Removed redundant padding in kernel for larger lora/dim sizes
Akshat-Tripathi ec0e181
Moved xm.mark_step() calls to move appropriate places
Akshat-Tripathi 38de473
Reduced number of graphs compiled
Akshat-Tripathi 8dabfab
Fixed memory usage problem
Akshat-Tripathi f5949a7
Lint
Akshat-Tripathi a7ae288
Lint
Akshat-Tripathi 2e67aa8
Removed first inference recompilation
Akshat-Tripathi 27b3c52
Fixed more recompilations
Akshat-Tripathi d1452af
Added flag to disabled add_lora_logits()
Akshat-Tripathi 1cc89a5
Lint
Akshat-Tripathi 93d3e8f
Fixed performance issue where the sampler would face long stalls
Akshat-Tripathi 9fb50b9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi 592c62f
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi e1aaed6
Fixed laning integration bug
Akshat-Tripathi 62500e1
Lint
Akshat-Tripathi eb72ab6
Removed LoRA vocab padding for TPU
Akshat-Tripathi 49157b1
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 54db22d
Fixed 0 padding issue with LoRA
Akshat-Tripathi 5232785
Changed TPU lora_vocab_padding_size to 1
Akshat-Tripathi 1b4c2f2
Fixed bug in bgmv_expand kernel - outputs weren't being written with …
Akshat-Tripathi c8f68d7
Changed TPU lora_vocab_padding_size to 1
Akshat-Tripathi ed3b245
Enabled lora bias
Akshat-Tripathi 4855791
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 9d35414
Replaced `enable_laning` flag with dim comparison
Akshat-Tripathi 54c00c3
Enabled fully sharded loras
Akshat-Tripathi f3e48a6
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi a4b2e27
Removed test benchmarking file
Akshat-Tripathi 9f0fdbe
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi fbddd3c
Refactored add_shrink to return a tensor not a tuple
Akshat-Tripathi 2012bbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi d1c11c8
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi 1803135
Removed tuple return in add_shrink()
Akshat-Tripathi 0eeb72c
Removed extra compilation
Akshat-Tripathi c1be9fe
Replaced copies with buffer donation to reduce memory usage
Akshat-Tripathi de5da33
Added explicit compilation in add_lora
Akshat-Tripathi 5adc67f
Removed LoRA ID collision
Akshat-Tripathi 342ff8b
Fix pre-commit
Akshat-Tripathi fc65edb
Reduced number of iterations in test_lora
Akshat-Tripathi 2f1da29
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi 7daaafa
Lint
Akshat-Tripathi 893ac04
Reduced pallas kernel test size
Akshat-Tripathi 2a0fce7
Added/removed comments
Akshat-Tripathi 4d42844
Fixed pallas kernel test
Akshat-Tripathi 50a06fc
Made LoRA e2e test more robust
Akshat-Tripathi 9b78b74
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi ca68ce6
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi e91774a
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import

N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]

DTYPES = [torch.float16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]


def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
    """
    Inputs: (All integers)
    T: Total number of tokens
    D: Input dim
    L: LoRA Dim
    N: N LoRAs

    Outputs:
    inputs: torch.Tensor - shape (T, D)
    loras: torch.Tensor - shape (N, 1, L, D)
    idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)

    ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
    """
    torch.manual_seed(seed)

    inputs = torch.randn((T, D), device="xla", dtype=dtype)
    loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")

    ref_output = ref_bgmv(inputs, loras, idxs)
    return inputs, loras, idxs, ref_output


def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    selected_loras = loras[idxs]
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(axis=1)

    batch_size, output_size, input_size = selected_loras.shape
    return (selected_loras @ inputs.reshape(
        (batch_size, input_size, 1))).reshape((batch_size, output_size))


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
    if op_type == "expand":
        D, L = L, D

    inputs, loras, idxs, ref_output = generate_test_data(
        T, D, L, N, seed, dtype)

    # Run bgmv
    if op_type == "shrink":
        output = torch.ops.xla.bgmv_shrink(inputs, loras, idxs)
    else:
        output = torch.ops.xla.bgmv_expand(inputs, loras.transpose(2, 3), idxs)

    # Make sure we have no NaNs
    assert not torch.any(torch.isnan(output))

    # Compare with reference output
    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
def test_lora_laning_correctness(T, D, L, N, dtype, seed):
    inputs, loras_a, idxs, _ = generate_test_data(T, D, L, N, seed, dtype)
    _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype)

    r1 = ref_bgmv(inputs, loras_a, idxs)
    r2 = ref_bgmv(r1, loras_b, idxs)

    o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs)
    o2 = torch.ops.xla.bgmv_expand(o1, loras_b.transpose(2, 3), idxs)

    # Compare with reference output
    assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2)
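
For orientation, the shrink and expand ops tested above compose into the standard per-token LoRA update y = x + (x · Aᵀ) · Bᵀ, with the adapter chosen per token by idxs. The following is a minimal pure-PyTorch sketch of that composition using the same (N, 1, rank, dim) weight layout as the tests; the helper name apply_lora_delta, the CPU tensors, and the omission of LoRA scaling are illustrative assumptions, not part of this PR.

# Hypothetical helper (not in this PR): applies a per-token LoRA delta
# using the same (N, 1, rank, dim) weight layout as the tests above.
import torch


def apply_lora_delta(x, lora_a, lora_b, idxs):
    # x:      (T, D)        token activations
    # lora_a: (N, 1, L, D)  "shrink" weights, rank L
    # lora_b: (N, 1, D, L)  "expand" weights
    # idxs:   (T,)          which LoRA each token uses, in [0, N)
    a = lora_a[idxs].squeeze(1)        # (T, L, D)
    b = lora_b[idxs].squeeze(1)        # (T, D, L)
    shrunk = a @ x.unsqueeze(-1)       # (T, L, 1) - mirrors bgmv_shrink
    delta = (b @ shrunk).squeeze(-1)   # (T, D)    - mirrors bgmv_expand
    return x + delta


T, D, L, N = 8, 64, 16, 4
x = torch.randn(T, D)
lora_a = torch.randn(N, 1, L, D)
lora_b = torch.randn(N, 1, D, L)
idxs = torch.randint(0, N, (T,))
y = apply_lora_delta(x, lora_a, lora_b, idxs)
assert y.shape == (T, D)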
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
    for all tests in this file
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        yield


@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_lora_e2e(num_loras: int):
    """
    This test ensures that we can run with LoRA adapters on the TPU backend.
    It verifies multiple capabilities:
    1. We can compile a model with LoRA adapters enabled
    2. We can run <num_loras> LoRA adapters
    3. We receive correct outputs when running with multiple LoRA adapters
    4. We can swap LoRA adapters between host and device
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                   num_scheduler_steps=1,
                   max_model_len=256,
                   max_seq_len_to_capture=256,
                   max_num_seqs=8,
                   enable_lora=True,
                   max_loras=num_loras,
                   max_lora_rank=8)

    prompt = "What is 1+1? \n"

    for _ in range(2):
        for i, req in enumerate(lora_requests):
            output = llm.generate(prompt,
                                  sampling_params=vllm.SamplingParams(
                                      max_tokens=256, temperature=0),
                                  lora_request=req)[0].outputs[0].text
            assert int(output.strip()[0]) == i + 1
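
For readers who want to try this flow outside of pytest, here is a minimal standalone sketch under stated assumptions: it reuses only APIs that already appear in the test above (vllm.LLM, LoRARequest, vllm.SamplingParams), the model and adapter names are copied from the test, and the environment-variable handling, max_loras choice, and printout are illustrative rather than part of this PR. Using fewer LoRA slots (max_loras) than the number of adapters cycled through exercises the host/device swapping described in capability 4 of the test docstring.

# Standalone sketch (assumptions noted above); not part of this PR.
import os

os.environ["VLLM_USE_V1"] = "1"  # Multi-LoRA is only supported on the v1 TPU backend

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
               max_model_len=256,
               max_num_seqs=8,
               enable_lora=True,
               max_loras=2,  # fewer slots than adapters, forcing host/device swaps
               max_lora_rank=8)

adapters = [
    LoRARequest(f"lora_adapter_{i}", i,
                f"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{i}_adapter")
    for i in range(1, 5)
]

for i, req in enumerate(adapters, start=1):
    text = llm.generate("What is 1+1? \n",
                        sampling_params=vllm.SamplingParams(max_tokens=16,
                                                            temperature=0),
                        lora_request=req)[0].outputs[0].text
    # Adapter i is trained to answer "i", so the first digit should equal i.
    print(f"adapter {i}: {text.strip()[:20]}")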
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
                                            bgmv_shrink)
from vllm.lora.ops.xla_ops.pallas import LORA_RANK_BLOCK_SIZE

__all__ = [
    "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "LORA_RANK_BLOCK_SIZE"
]
Reviewer: How can we know LoRA helps generate the expected result? I guess we can do something similar to vllm/tests/lora/test_mixtral.py (line 63 in 7b5ecf7).
Author: Ah no, we don't need to: the adapters make the model give incorrect answers to "What is 1+1?", so the nth adapter makes it answer n instead of 2 (except the 2nd adapter).
Reviewer: Thanks, can we add some comments to explain that?
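
A possible wording for such a comment in the e2e test, based on the author's explanation above (the exact phrasing is illustrative, not part of the PR):

# NOTE: The "1_plus_1_equals_{n}" adapters are intentionally trained so that
# adapter n makes the model answer "n" to "What is 1+1?" (i.e. every adapter
# except n == 2 produces a deliberately wrong answer). Checking that the first
# digit of each output equals the adapter's index therefore confirms that the
# requested adapter, rather than the base model or a different adapter, was
# applied to that request.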