[Do Not Merge] - LoRA V1 Reference PR #11613

Closed
32 changes: 28 additions & 4 deletions benchmarks/benchmark_serving.py
@@ -370,6 +370,7 @@ async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
burstiness: float = 1.0,
lora_requests: Optional[List[str]] = None,
) -> AsyncGenerator[Tuple[str, int, int], None]:
"""
Asynchronously generates requests at a specified rate
@@ -390,14 +391,19 @@
(burstiness > 1) results in a more uniform arrival of requests.
"""
input_requests = iter(input_requests)
if lora_requests:
lora_requests = iter(lora_requests)

# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
theta = 1.0 / (request_rate * burstiness)

for request in input_requests:
yield request
if lora_requests:
yield request, next(lora_requests)
else:
yield request, None

if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
@@ -537,6 +543,7 @@ async def benchmark(
ignore_eos: bool,
goodput_config_dict: Dict[str, float],
max_concurrency: Optional[int],
lora_requests: Optional[List[str]] = None,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -614,10 +621,13 @@ async def limited_request_func(request_func_input, pbar):

benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
async for request, lora_request in get_request(input_requests,
request_rate, burstiness,
lora_requests):
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id,
model_name=model_name,
request_func_input = RequestFuncInput(model=lora_request or model_id,
model_name=lora_request
or model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
@@ -874,6 +884,13 @@ def main(args: argparse.Namespace):

goodput_config_dict = check_goodput_args(args)

#input_requests: List[Tuple[str, int, int]],
lora_requests = None
if args.lora_models:
lora_requests = [
random.choice(args.lora_models) for _ in range(len(input_requests))
]

# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()
@@ -900,6 +917,7 @@ def main(args: argparse.Namespace):
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_requests=lora_requests,
))

# Save config and results to json
@@ -1230,6 +1248,12 @@ def main(args: argparse.Namespace):
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')

parser.add_argument(
'--lora-models',
nargs='+',
default=[],
)

parser.add_argument("--served-model-name",
type=str,
default=None,
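
For reference, the sketch below distills the serving-benchmark change in this file: one adapter name is sampled per request (mirroring the --lora-models handling in main()), and get_request yields (request, lora_model) pairs, falling back to None when no adapters are supplied. This is a minimal standalone sketch; the prompt strings and adapter names are illustrative placeholders, not values from this PR.

import asyncio
import random
from typing import AsyncGenerator, List, Optional, Tuple


async def get_request_sketch(
    input_requests: List[str],
    lora_requests: Optional[List[str]] = None,
) -> AsyncGenerator[Tuple[str, Optional[str]], None]:
    # Simplified pairing logic: yield each request with its LoRA model name, or None.
    lora_iter = iter(lora_requests) if lora_requests else None
    for request in input_requests:
        yield request, (next(lora_iter) if lora_iter else None)


async def _demo() -> None:
    prompts = ["prompt-0", "prompt-1", "prompt-2"]
    adapters = ["lora-adapter-a", "lora-adapter-b"]  # hypothetical adapter names
    # Mirrors main(): one randomly chosen adapter per input request.
    lora_requests = [random.choice(adapters) for _ in prompts]
    async for prompt, lora in get_request_sketch(prompts, lora_requests):
        # benchmark() then uses `lora or model_id` as the model field of the request.
        print(prompt, "->", lora)


asyncio.run(_demo())

When --lora-models is omitted, the generator degrades to the previous behavior and the base model_id is used unchanged.
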
80 changes: 66 additions & 14 deletions benchmarks/benchmark_throughput.py
@@ -3,6 +3,7 @@
import argparse
import dataclasses
import json
import pickle
import random
import time
from functools import cache
@@ -26,6 +27,9 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators

SAMPLING_TEMPERATURE = 0.0
SAMPLING_TOP_P = 1.0


@dataclasses.dataclass
class SampleRequest:
@@ -166,6 +170,7 @@ def run_vllm(
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
do_profile: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
@@ -180,8 +185,8 @@
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
temperature=SAMPLING_TEMPERATURE,
top_p=SAMPLING_TOP_P,
ignore_eos=True,
max_tokens=request.expected_output_len,
))
@@ -191,13 +196,23 @@

use_beam_search = False

outputs = None
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
if do_profile:
llm.start_profile()
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()

if do_profile:
llm.stop_profile()
# it takes a while to generate the profile !!
print("Called llm.stop_profile() ... Sleeping for 100s on client "
"side for profile trace dump to finish !!")
time.sleep(100)
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
@@ -214,14 +229,15 @@
ignore_eos=True,
))
end = time.perf_counter()
return end - start
return end - start, outputs


async def run_vllm_async(
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
do_profile: bool = False,
) -> float:
from vllm import SamplingParams

@@ -239,14 +255,16 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
temperature=SAMPLING_TEMPERATURE,
top_p=SAMPLING_TOP_P,
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests.append(request.lora_request)

generators = []
if do_profile:
await llm.start_profile()
start = time.perf_counter()
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
@@ -256,10 +274,25 @@
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
outputs_dict = {}
async for i, res in all_gens:
pass
outputs_dict[i] = res

end = time.perf_counter()
return end - start
elapsed = end - start

if do_profile:
await llm.stop_profile()
print("Called llm.stop_profile() ... Sleeping for 100s on client"
"side for profile trace dump to finish !!")
time.sleep(100)

num_prompts = len(prompts)
outputs = []
for i in range(num_prompts):
outputs.append(outputs_dict[i])

return elapsed, outputs


def run_hf(
@@ -392,16 +425,25 @@ def main(args: argparse.Namespace):
for request in requests)
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
elapsed_time, outputs = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
do_profile=args.profile,
))
else:
elapsed_time = run_vllm(requests, args.n,
EngineArgs.from_cli_args(args))
elapsed_time, outputs = run_vllm(requests,
args.n,
EngineArgs.from_cli_args(args),
do_profile=args.profile)

if args.pickle_outputs:
print("Pickling request outputs : ")
with open("outputs.pkl", "wb+") as f:
pickle.dump(outputs, f)

elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -491,6 +533,16 @@ def main(args: argparse.Namespace):
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")

parser.add_argument("--profile",
action='store_true',
default=False,
help="Profile the entire run")

parser.add_argument("--pickle-outputs",
action="store_true",
default=False,
help="Pickle outputs got from benchmark")

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
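
As a usage note for the two flags added above: --profile brackets the generate call with start_profile()/stop_profile(), and --pickle-outputs dumps the collected outputs (RequestOutput objects) to outputs.pkl in the working directory. Below is a minimal sketch for inspecting that dump afterwards, assuming the benchmark was run with --pickle-outputs and that vLLM is importable in the same environment (unpickling needs its classes).

import pickle

# Load the outputs dumped by `benchmark_throughput.py ... --pickle-outputs`.
with open("outputs.pkl", "rb") as f:
    outputs = pickle.load(f)

# One entry per input request, in submission order
# (the async path re-orders results by request index before returning).
for i, output in enumerate(outputs[:3]):
    print(i, output)
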
17 changes: 17 additions & 0 deletions tests/lora/conftest.py
@@ -306,3 +306,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[False, True])
def run_with_both_engines_lora(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
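
To illustrate how this fixture is consumed: each test module adds a small autouse wrapper (as in the test files below), and an individual test can opt out of the V1 pass with the skip_v1 marker that the fixture checks for. The sketch below is illustrative only; the test name is hypothetical, and registering the marker (e.g. in pytest configuration) is not shown in this diff.

import pytest


# Hypothetical test in a module that already defines the autouse `v1` wrapper fixture.
# run_with_both_engines_lora sees the `skip_v1` marker and skips the V1 run,
# so this test executes only with VLLM_USE_V1 patched to False.
@pytest.mark.skip_v1
def test_v0_only_behavior():
    assert True
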
8 changes: 8 additions & 0 deletions tests/lora/test_baichuan.py
@@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
10 changes: 10 additions & 0 deletions tests/lora/test_chatglm3_tp.py
@@ -2,6 +2,8 @@

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@@ -47,6 +49,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
@@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
9 changes: 9 additions & 0 deletions tests/lora/test_llama_tp.py
@@ -2,6 +2,7 @@

from typing import List

import pytest
import ray

import vllm
@@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

8 changes: 8 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -30,6 +30,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):