
Commit cf27543

enhance multi-turn chat feature, update loadgen dispatcher
- The request dispatcher now supports assigning a request to a specific worker. - Multi-turn chat is enhanced with load balancing at both the worker and the user-session level. - LazyLoadInferenceAPIData is introduced to standardize the lazy loading of inference data. This replaces the previous implementation and provides a cleaner, extensible design for data handling between the data generator, load generator, and API data layers.
1 parent e74e507 commit cf27543

File tree

14 files changed, +217 −63 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Inference Perf is a GenAI inference performance benchmarking tool that allows yo
 * Supports benchmarking large deployments with frameworks like [llm-d](https://llm-d.ai/), [Dynamo](https://docs.nvidia.com/dynamo/latest/) and [Inference Gateway](https://gateway-api-inference-extension.sigs.k8s.io/).
 * Supports specifying an exact input and output distribution to simulate different scenarios - Gaussian distribution, fixed length, min-max cases are all supported.
 * Generates different load patterns and can benchmark specific cases like burst traffic, scaling to saturation and other autoscaling / routing scenarios.
+* Supports multi-turn chat conversations: it keeps the context of a series of messages to simulate a conversation, and the request in each chat round carries the previous messages as a prefix. See the example [config-multi-turn](examples/vllm/config-shared-prefix-multi-turn.yml).

 ## Roadmap

examples/vllm/config-shared-prefix-multi-turn.yml

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+load:
+  type: constant
+  num_workers: 2
+  worker_max_concurrency: 10
+  stages:
+  - rate: 5
+    duration: 10
+api:
+  type: completion
+server:
+  type: vllm
+  model_name: HuggingFaceTB/SmolLM2-135M-Instruct
+  base_url: http://0.0.0.0:8000
+  ignore_eos: true
+tokenizer:
+  pretrained_model_name_or_path: HuggingFaceTB/SmolLM2-135M-Instruct
+data:
+  type: shared_prefix
+  shared_prefix:
+    num_groups: 2  # Number of distinct users
+    num_prompts_per_group: 25  # Number of unique questions per user
+    system_prompt_len: 100  # Length of the first prefix (in tokens); simulates initialization of a system prompt
+    question_len: 50  # Length of the unique question part (in tokens)
+    output_len: 50  # Target length for the model's generated output (in tokens)
+    enable_multi_turn_chat: true  # Enable multi-turn chat and create a user session for each group. The chat context is appended to each request in the group.
+metrics:
+  type: prometheus
+  prometheus:
+    url: http://localhost:9090
+    scrape_interval: 15
+report:
+  request_lifecycle:
+    summary: true
+    per_stage: true
+    per_request: true
inference_perf/apis/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
+from .base import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo, LazyLoadInferenceAPIData
 from .chat import ChatCompletionAPIData, ChatMessage
 from .completion import CompletionAPIData

 __all__ = [
     "InferenceAPIData",
+    "LazyLoadInferenceAPIData",
     "InferenceInfo",
     "RequestLifecycleMetric",
     "ErrorResponseInfo",
inference_perf/apis/base.py

Lines changed: 31 additions & 0 deletions
@@ -44,6 +44,9 @@ class RequestLifecycleMetric(BaseModel):


 class InferenceAPIData(BaseModel):
+    # loadgen should assign this request to its preferred worker if possible
+    prefered_worker_id: int = -1  # no preferred worker by default
+
     @abstractmethod
     def get_api_type(self) -> APIType:
         raise NotImplementedError
@@ -64,3 +67,31 @@ async def process_failure(
         self, response: Optional[ClientResponse], config: APIConfig, tokenizer: CustomTokenizer, exception: Exception
     ) -> Optional[InferenceInfo]:
         pass  # no-op by default
+
+
+class LazyLoadInferenceAPIData(InferenceAPIData):
+    """
+    InferenceAPIData that loads its data lazily.
+    This is useful for multiprocessing, where the data cannot be pickled or needs to be initialized in the worker process.
+    This class should not carry the data itself, only the payload the data generator needs to return the real API data later.
+    In most cases, the generator should rely on data_index as the reference; if more payload is needed, extend this class.
+    """
+
+    data_index: int
+
+    def get_api_type(self) -> APIType:
+        raise NotImplementedError("LazyLoadInferenceAPIData doesn't support this operation")
+
+    def get_route(self) -> str:
+        raise NotImplementedError("LazyLoadInferenceAPIData doesn't support this operation")
+
+    async def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
+        raise NotImplementedError("LazyLoadInferenceAPIData doesn't support this operation")
+
+    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
+        raise NotImplementedError("LazyLoadInferenceAPIData doesn't support this operation")
+
+    async def process_failure(
+        self, response: Optional[ClientResponse], config: APIConfig, tokenizer: CustomTokenizer, exception: Exception
+    ) -> Optional[InferenceInfo]:
+        raise NotImplementedError("LazyLoadInferenceAPIData doesn't support this operation")
inference_perf/apis/completion.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@
 class CompletionAPIData(InferenceAPIData):
     prompt: str
     max_tokens: int = 0
-    output_token: str = ""
+    model_response: str = ""

     def get_api_type(self) -> APIType:
         return APIType.Completion
@@ -63,7 +63,7 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
             output_text += text
         prompt_len = tokenizer.count_tokens(self.prompt)
         output_len = tokenizer.count_tokens(output_text)
-        self.output_token = output_text
+        self.model_response = output_text
         return InferenceInfo(
             input_tokens=prompt_len,
             output_tokens=output_len,
@@ -77,5 +77,5 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
             return InferenceInfo(input_tokens=prompt_len)
         output_text = choices[0].get("text", "")
         output_len = tokenizer.count_tokens(output_text)
-        self.output_token = output_text
+        self.model_response = output_text
         return InferenceInfo(input_tokens=prompt_len, output_tokens=output_len)
inference_perf/apis/user_session.py

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,6 @@
 import logging
 import asyncio
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 from pydantic import ConfigDict, Field

 from aiohttp import ClientResponse
@@ -20,13 +20,13 @@ def __init__(self, user_session_id: str, context: str = ""):
         self.contexts = context if context else ""
         self._current_round = 0
         self._in_flight: asyncio.Lock = asyncio.Lock()
-        self._waiting_rounds: asyncio.PriorityQueue[Tuple[int, asyncio.Future[bool]]] = asyncio.PriorityQueue()
+        self._waiting_rounds: asyncio.Queue[asyncio.Future[bool]] = asyncio.Queue()

     async def get_context(self, round: int) -> str:
         if not self._waiting_rounds.empty() or self._in_flight.locked():
             # entering waiting queue
             future: asyncio.Future[bool] = asyncio.Future()
-            self._waiting_rounds.put_nowait((round, future))
+            self._waiting_rounds.put_nowait(future)
             await future
         await self._in_flight.acquire()
         self._current_round += 1
@@ -36,7 +36,7 @@ def update_context(self, response: str) -> None:
         self.contexts = response

         if not self._waiting_rounds.empty():
-            _, future = self._waiting_rounds.get_nowait()
+            future = self._waiting_rounds.get_nowait()
             future.set_result(True)

         self._in_flight.release()
@@ -49,6 +49,7 @@ class UserSessionCompletionAPIData(CompletionAPIData):

     async def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
         self._session_context = await self.user_session.get_context(self.target_round)
+        # TODO: Currently only the prompt style (concatenated messages) is supported. Add support for messages-style payloads.
         self.prompt = self._session_context + " " + self.prompt
         # TODO: The combined prompt (session context + current prompt) might exceed the model's
         # maximum sequence length. Implement truncation logic/strategy to prevent
@@ -62,7 +63,7 @@ def update_inference_info(self, inference_info: InferenceInfo) -> None:
     async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
         inference_info = await super().process_response(response, config, tokenizer)
         self.update_inference_info(inference_info)
-        self.user_session.update_context(self.prompt + " " + self.output_token)
+        self.user_session.update_context(self.prompt + " " + self.model_response)
         return inference_info

     async def process_failure(
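The session keeps the rounds of one conversation in order: a request either takes the in-flight lock or parks a Future in the waiting queue, and update_context wakes exactly one waiter after the previous response has been folded into the context. Below is a minimal standalone asyncio sketch of that waiting pattern; it is a simplified stand-in, not the project's UserSession class, and drops the round bookkeeping.

import asyncio

# Simplified sketch of the UserSession waiting pattern: one request per session
# is in flight at a time; later rounds wait on a Future until the earlier
# round's response has been appended to the shared context.
class TinySession:
    def __init__(self) -> None:
        self.context = ""
        self._in_flight = asyncio.Lock()
        self._waiting: asyncio.Queue[asyncio.Future[bool]] = asyncio.Queue()

    async def get_context(self) -> str:
        if not self._waiting.empty() or self._in_flight.locked():
            fut: asyncio.Future[bool] = asyncio.Future()
            self._waiting.put_nowait(fut)
            await fut                       # park until the previous round finishes
        await self._in_flight.acquire()
        return self.context

    def update_context(self, new_context: str) -> None:
        self.context = new_context
        if not self._waiting.empty():
            self._waiting.get_nowait().set_result(True)  # wake the next round
        self._in_flight.release()

async def round_trip(session: TinySession, question: str) -> None:
    prefix = await session.get_context()
    await asyncio.sleep(0.01)               # simulate the in-flight request
    prompt = (prefix + " " + question).strip()
    answer = f"answer-to[{question}]"       # stand-in for the model response
    session.update_context(prompt + " " + answer)

async def main() -> None:
    s = TinySession()
    await asyncio.gather(*(round_trip(s, f"q{i}") for i in range(3)))
    print(s.context)                        # rounds q0, q1, q2 appended in order

asyncio.run(main())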

inference_perf/config.py

Lines changed: 1 addition & 2 deletions
@@ -59,8 +59,7 @@ class SharedPrefix(BaseModel):
     system_prompt_len: int = 100
     question_len: int = 50
     output_len: int = 50
-    # create user session for each group. The chat context will be appended for the each request in the group.
-    group_as_user_session: bool = False
+    enable_multi_turn_chat: bool = False


 class DataConfig(BaseModel):
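The old group_as_user_session flag is replaced by enable_multi_turn_chat, so existing configs need the new key under shared_prefix. A tiny sketch of the renamed field in use, assuming the other SharedPrefix fields keep their defaults as shown in the diff:

from inference_perf.config import SharedPrefix

# The renamed flag defaults to False, so single-turn behavior is unchanged
# unless a config opts in explicitly.
cfg = SharedPrefix(enable_multi_turn_chat=True)
print(cfg.enable_multi_turn_chat)  # True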

inference_perf/datagen/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import DataGenerator
+from .base import DataGenerator, LazyLoadDataMixin
 from .mock_datagen import MockDataGenerator
 from .hf_sharegpt_datagen import HFShareGPTDataGenerator
 from .synthetic_datagen import SyntheticDataGenerator
@@ -23,6 +23,7 @@

 __all__ = [
     "DataGenerator",
+    "LazyLoadDataMixin",
     "MockDataGenerator",
     "HFShareGPTDataGenerator",
     "SyntheticDataGenerator",

inference_perf/datagen/base.py

Lines changed: 31 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from inference_perf.apis import InferenceAPIData
+from inference_perf.apis import InferenceAPIData, LazyLoadInferenceAPIData
 from inference_perf.utils.custom_tokenizer import CustomTokenizer
 from inference_perf.config import APIConfig, APIType, DataConfig, Distribution, SharedPrefix
 from abc import ABC, abstractmethod
@@ -64,3 +64,33 @@ def is_io_distribution_supported(self) -> bool:
     @abstractmethod
     def is_shared_prefix_supported(self) -> bool:
         raise NotImplementedError
+
+    # notify the loadgen whether this request has a preferred worker
+    def is_prefered_worker_requested(self) -> bool:
+        return False
+
+
+class LazyLoadDataMixin(ABC):
+    """
+    Mixin for data generators that support lazy loading of InferenceAPIData.
+    This is useful for multiprocessing where the actual InferenceAPIData objects
+    might be large or unpickleable, or need to be initialized in the worker process.
+    """
+
+    @abstractmethod
+    def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
+        """
+        Returns the real InferenceAPIData object for the given lazy-load placeholder.
+        This method is usually called by worker processes to lazily load data, unless MP mode is disabled.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def get_request(data_generator: DataGenerator, data: InferenceAPIData) -> InferenceAPIData:
+        if isinstance(data, LazyLoadInferenceAPIData):
+            if isinstance(data_generator, LazyLoadDataMixin):
+                return data_generator.load_lazy_data(data)
+            else:
+                raise NotImplementedError("Data Generator doesn't support lazy loading of the requested InferenceAPIData")
+        else:
+            return data
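Putting the pieces together, the intended flow is: the generator yields lightweight LazyLoadInferenceAPIData placeholders, the placeholders cross the process boundary, and the worker resolves them against its own generator instance via LazyLoadDataMixin.get_request. A minimal sketch of that call pattern follows; handle_in_worker is an illustrative name, not the loadgen's actual code.

from inference_perf.apis import InferenceAPIData
from inference_perf.datagen import DataGenerator, LazyLoadDataMixin

# Illustrative worker-side resolution of a lazy placeholder. The worker owns its
# own data_generator instance, so only the small placeholder has to be pickled
# and sent across the process boundary.
def handle_in_worker(data_generator: DataGenerator, item: InferenceAPIData) -> InferenceAPIData:
    # Works for both lazy placeholders and fully materialized requests:
    # get_request() returns `item` unchanged unless it is a LazyLoadInferenceAPIData.
    request = LazyLoadDataMixin.get_request(data_generator, item)
    return request  # ready for to_payload() / process_response()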

inference_perf/datagen/random_datagen.py

Lines changed: 5 additions & 14 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numpy as np
-from inference_perf.apis import InferenceAPIData, CompletionAPIData
+from inference_perf.apis import InferenceAPIData, CompletionAPIData, LazyLoadInferenceAPIData
 from inference_perf.utils.custom_tokenizer import CustomTokenizer
 from inference_perf.utils.distribution import generate_distribution
 from .base import DataGenerator
@@ -80,7 +80,9 @@ def is_io_distribution_supported(self) -> bool:
     def is_shared_prefix_supported(self) -> bool:
         return False

-    def get_request(self, n: int) -> InferenceAPIData:
+    def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
+        n = data.data_index
+
         if self.tokenizer is None:
             raise ValueError("Tokenizer is required for RandomDataGenerator")

@@ -99,16 +101,5 @@ def get_data(self) -> Generator[InferenceAPIData, None, None]:

         i = 0
         while True:
-            prompt_text: str
-            if self.input_lengths[i] <= 0:
-                random_token_ids_list = []
-            else:
-                random_token_ids = np.random.randint(0, self.vocab_size, size=self.input_lengths[i], dtype=np.int64)
-                random_token_ids_list = random_token_ids.tolist()
-            prompt_text = self.tokenizer.get_tokenizer().decode(random_token_ids_list)
-
-            yield CompletionAPIData(
-                prompt=prompt_text,
-                max_tokens=self.output_lengths[i],
-            )
+            yield LazyLoadInferenceAPIData(data_index=i)
             i += 1
