Skip to content

Commit 669fb62

Browse files
authored
feat: add rankings prompt and query token options (#498)
1 parent 42c6829 commit 669fb62

File tree

17 files changed

+548
-59
lines changed

17 files changed

+548
-59
lines changed

docs/cli_options.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ Mean number of passages per rankings entry (per query) (default 1).
124124
Stddev for passages per rankings entry (default 0).
125125
<br>_Default: `0`_
126126

127+
#### `--rankings-passages-prompt-token-mean` `<int>`
128+
129+
Mean number of tokens in a passage entry for rankings (default 550).
130+
<br>_Default: `550`_
131+
132+
#### `--rankings-passages-prompt-token-stddev` `<int>`
133+
134+
Stddev for number of tokens in a passage entry for rankings (default 0).
135+
<br>_Default: `0`_
136+
137+
#### `--rankings-query-prompt-token-mean` `<int>`
138+
139+
Mean number of tokens in a query entry for rankings (default 550).
140+
<br>_Default: `550`_
141+
142+
#### `--rankings-query-prompt-token-stddev` `<int>`
143+
144+
Stddev for number of tokens in a query entry for rankings (default 0).
145+
<br>_Default: `0`_
146+
127147
## Audio Input Options
128148

129149
#### `--audio-batch-size`, `--batch-size-audio` `<int>`

docs/tutorials/rankings.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,14 @@ aiperf profile \
4343
--request-count 10 \
4444
--rankings-passages-mean 5 \
4545
--rankings-passages-stddev 1 \
46-
--prompt-input-tokens-mean 32 \
47-
--prompt-input-tokens-stddev 8
46+
--rankings-passages-prompt-token-mean 32 \
47+
--rankings-passages-prompt-token-stddev 8 \
48+
--rankings-query-prompt-token-mean 16 \
49+
--rankings-query-prompt-token-stddev 4
4850
```
4951

52+
> **Note:** The rankings-specific token options cannot be used together with `--prompt-input-tokens-mean` or `--prompt-input-tokens-stddev`. Use the rankings-specific options for controlling token counts in rankings queries and passages.
53+
5054
### Profile using Custom Inputs
5155

5256
Create a file named rankings.jsonl where each line represents a ranking request with a query and one or more passages.

src/aiperf/common/config/config_defaults.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class InputDefaults:
5959
NUM_DATASET_ENTRIES = 100
6060
RANKINGS_PASSAGES_MEAN = 1
6161
RANKINGS_PASSAGES_STDDEV = 0
62+
RANKINGS_PASSAGES_PROMPT_TOKEN_MEAN = 550
63+
RANKINGS_PASSAGES_PROMPT_TOKEN_STDDEV = 0
64+
RANKINGS_QUERY_PROMPT_TOKEN_MEAN = 550
65+
RANKINGS_QUERY_PROMPT_TOKEN_STDDEV = 0
6266

6367

6468
@dataclass(frozen=True)

src/aiperf/common/config/input_config.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,62 @@ def validate_goodput(self) -> Self:
331331
),
332332
] = InputDefaults.RANKINGS_PASSAGES_STDDEV
333333

334+
rankings_passages_prompt_token_mean: Annotated[
335+
int,
336+
Field(
337+
ge=1,
338+
description=(
339+
"Mean number of tokens in a passage entry for rankings (default 550)."
340+
),
341+
),
342+
CLIParameter(
343+
name=("--rankings-passages-prompt-token-mean",),
344+
group=_CLI_GROUP,
345+
),
346+
] = InputDefaults.RANKINGS_PASSAGES_PROMPT_TOKEN_MEAN
347+
348+
rankings_passages_prompt_token_stddev: Annotated[
349+
int,
350+
Field(
351+
ge=0,
352+
description=(
353+
"Stddev for number of tokens in a passage entry for rankings (default 0)."
354+
),
355+
),
356+
CLIParameter(
357+
name=("--rankings-passages-prompt-token-stddev",),
358+
group=_CLI_GROUP,
359+
),
360+
] = InputDefaults.RANKINGS_PASSAGES_PROMPT_TOKEN_STDDEV
361+
362+
rankings_query_prompt_token_mean: Annotated[
363+
int,
364+
Field(
365+
ge=1,
366+
description=(
367+
"Mean number of tokens in a query entry for rankings (default 550)."
368+
),
369+
),
370+
CLIParameter(
371+
name=("--rankings-query-prompt-token-mean",),
372+
group=_CLI_GROUP,
373+
),
374+
] = InputDefaults.RANKINGS_QUERY_PROMPT_TOKEN_MEAN
375+
376+
rankings_query_prompt_token_stddev: Annotated[
377+
int,
378+
Field(
379+
ge=0,
380+
description=(
381+
"Stddev for number of tokens in a query entry for rankings (default 0)."
382+
),
383+
),
384+
CLIParameter(
385+
name=("--rankings-query-prompt-token-stddev",),
386+
group=_CLI_GROUP,
387+
),
388+
] = InputDefaults.RANKINGS_QUERY_PROMPT_TOKEN_STDDEV
389+
334390
audio: AudioConfig = AudioConfig()
335391
image: ImageConfig = ImageConfig()
336392
video: VideoConfig = VideoConfig()

src/aiperf/common/config/user_config.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from aiperf.common.config.output_config import OutputConfig
2121
from aiperf.common.config.tokenizer_config import TokenizerConfig
2222
from aiperf.common.enums import CustomDatasetType, GPUTelemetryMode
23+
from aiperf.common.enums.plugin_enums import EndpointType
2324
from aiperf.common.enums.timing_enums import RequestRateMode, TimingMode
2425
from aiperf.common.utils import load_json_str
2526

@@ -398,3 +399,68 @@ def validate_concurrency_limits(self) -> Self:
398399
)
399400

400401
return self
402+
403+
@model_validator(mode="after")
404+
def validate_rankings_token_options(self) -> Self:
405+
"""Validate rankings token options usage."""
406+
407+
# Check if prompt input tokens have been changed from defaults
408+
prompt_tokens_modified = any(
409+
field in self.input.prompt.input_tokens.model_fields_set
410+
for field in ["mean", "stddev"]
411+
)
412+
413+
# Check if any rankings-specific token options have been changed from defaults
414+
rankings_token_fields = [
415+
"rankings_passages_prompt_token_mean",
416+
"rankings_passages_prompt_token_stddev",
417+
"rankings_query_prompt_token_mean",
418+
"rankings_query_prompt_token_stddev",
419+
]
420+
rankings_tokens_modified = any(
421+
field in self.input.model_fields_set for field in rankings_token_fields
422+
)
423+
424+
# Check if any rankings-specific passage options have been changed from defaults
425+
rankings_passages_fields = [
426+
"rankings_passages_mean",
427+
"rankings_passages_stddev",
428+
]
429+
rankings_passages_modified = any(
430+
field in self.input.model_fields_set for field in rankings_passages_fields
431+
)
432+
433+
rankings_options_modified = (
434+
rankings_tokens_modified or rankings_passages_modified
435+
)
436+
437+
endpoint_type_is_rankings = "rankings" in self.endpoint.type.lower()
438+
439+
# Validate that rankings options are only used with rankings endpoints
440+
rankings_endpoints = [
441+
endpoint_type
442+
for endpoint_type in EndpointType
443+
if "rankings" in endpoint_type.lower()
444+
]
445+
if rankings_options_modified and not endpoint_type_is_rankings:
446+
raise ValueError(
447+
f"Rankings-specific options (--rankings-passages-mean, --rankings-passages-stddev, "
448+
"--rankings-passages-prompt-token-mean, --rankings-passages-prompt-token-stddev, "
449+
"--rankings-query-prompt-token-mean, --rankings-query-prompt-token-stddev) "
450+
"can only be used with rankings endpoint types "
451+
f"Rankings endpoints: ({', '.join(rankings_endpoints)})."
452+
)
453+
454+
# Validate that prompt tokens and rankings tokens are not both set
455+
if prompt_tokens_modified and (
456+
rankings_tokens_modified or endpoint_type_is_rankings
457+
):
458+
raise ValueError(
459+
"The --prompt-input-tokens-mean/--prompt-input-tokens-stddev options "
460+
"cannot be used together with rankings-specific token options or the rankings endpoints"
461+
"Ranking options: (--rankings-passages-prompt-token-mean, --rankings-passages-prompt-token-stddev, "
462+
"--rankings-query-prompt-token-mean, --rankings-query-prompt-token-stddev, ). "
463+
f"Rankings endpoints: ({', '.join(rankings_endpoints)})."
464+
"Please use only one set of options."
465+
)
466+
return self

src/aiperf/common/enums/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
FrequencyMetricUnit,
5757
FrequencyMetricUnitInfo,
5858
GenericMetricUnit,
59+
MetricDictValueTypeT,
5960
MetricFlags,
6061
MetricOverTimeUnit,
6162
MetricOverTimeUnitInfo,
@@ -64,8 +65,10 @@
6465
MetricTimeUnit,
6566
MetricTimeUnitInfo,
6667
MetricType,
68+
MetricUnitT,
6769
MetricValueType,
6870
MetricValueTypeInfo,
71+
MetricValueTypeT,
6972
MetricValueTypeVarT,
7073
PowerMetricUnit,
7174
PowerMetricUnitInfo,
@@ -141,6 +144,7 @@
141144
"LifecycleState",
142145
"MediaType",
143146
"MessageType",
147+
"MetricDictValueTypeT",
144148
"MetricFlags",
145149
"MetricOverTimeUnit",
146150
"MetricOverTimeUnitInfo",
@@ -149,8 +153,10 @@
149153
"MetricTimeUnit",
150154
"MetricTimeUnitInfo",
151155
"MetricType",
156+
"MetricUnitT",
152157
"MetricValueType",
153158
"MetricValueTypeInfo",
159+
"MetricValueTypeT",
154160
"MetricValueTypeVarT",
155161
"ModelSelectionStrategy",
156162
"PowerMetricUnit",

src/aiperf/dataset/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
BaseLoader,
4141
BasePublicDatasetLoader,
4242
CustomDatasetT,
43+
Filename,
4344
MediaConversionMixin,
4445
MooncakeTrace,
4546
MooncakeTraceDatasetLoader,
@@ -70,6 +71,7 @@
7071
"CustomDatasetT",
7172
"DEFAULT_CORPUS_FILE",
7273
"DatasetManager",
74+
"Filename",
7375
"ImageGenerator",
7476
"MP3_SUPPORTED_SAMPLE_RATES",
7577
"MediaConversionMixin",

src/aiperf/dataset/composer/synthetic_rankings.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer):
2323

2424
self.session_id_generator = SessionIDGenerator(seed=config.input.random_seed)
2525
self._passages_rng = rng.derive("dataset.rankings.passages")
26+
self._passages_token_rng = rng.derive("dataset.rankings.passages.tokens")
27+
self._query_token_rng = rng.derive("dataset.rankings.query.tokens")
2628

2729
# Set default sampling strategy for synthetic rankings dataset if not explicitly set
2830
if self.config.input.dataset_sampling_strategy is None:
@@ -33,12 +35,6 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer):
3335
f"Using default sampling strategy for synthetic rankings dataset: {InputDefaults.DATASET_SAMPLING_STRATEGY}"
3436
)
3537

36-
if self.config.input.prompt.input_tokens.mean <= 0:
37-
raise ValueError(
38-
"Synthetic rankings data generation requires text prompts to be enabled. "
39-
"Please set --prompt-input-tokens-mean > 0."
40-
)
41-
4238
def create_dataset(self) -> list[Conversation]:
4339
"""Generate synthetic dataset for the rankings endpoint.
4440
@@ -64,17 +60,22 @@ def _create_turn(self, num_passages: int) -> Turn:
6460
"""Create a single ranking turn with one synthetic query and multiple synthetic passages."""
6561
turn = Turn()
6662

67-
query_text = self.prompt_generator.generate(
68-
mean=self.config.input.prompt.input_tokens.mean,
69-
stddev=self.config.input.prompt.input_tokens.stddev,
63+
query_text = self.prompt_generator.generate_prompt(
64+
self.prompt_generator.calculate_num_tokens(
65+
self.config.input.rankings_query_prompt_token_mean,
66+
self.config.input.rankings_query_prompt_token_stddev,
67+
)
7068
)
7169
query = Text(name="query", contents=[query_text])
7270

71+
# Generate passages with rankings-specific token counts (per passage)
7372
passages = Text(name="passages")
7473
for _ in range(num_passages):
75-
passage_text = self.prompt_generator.generate(
76-
mean=self.config.input.prompt.input_tokens.mean,
77-
stddev=self.config.input.prompt.input_tokens.stddev,
74+
passage_text = self.prompt_generator.generate_prompt(
75+
self.prompt_generator.calculate_num_tokens(
76+
self.config.input.rankings_passages_prompt_token_mean,
77+
self.config.input.rankings_passages_prompt_token_stddev,
78+
)
7879
)
7980
passages.contents.append(passage_text)
8081

src/aiperf/dataset/generator/prompt.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _create_prefix_prompt_pool(self) -> None:
123123
raise NotInitializedError("Tokenized corpus is not initialized.")
124124

125125
self._prefix_prompts = [
126-
self._generate_prompt(self.config.prefix_prompt.length)
126+
self.generate_prompt(self.config.prefix_prompt.length)
127127
for _ in range(self.config.prefix_prompt.pool_size)
128128
]
129129
self.debug(
@@ -137,6 +137,7 @@ def generate(
137137
hash_ids: list[int] | None = None,
138138
) -> str:
139139
"""Generate a synthetic prompt with the configuration parameters.
140+
Serves as a wrapper around other internal methods to provide a unified interface.
140141
141142
Args:
142143
mean: The mean of the normal distribution.
@@ -151,10 +152,24 @@ def generate(
151152
mean, hash_ids, self.config.input_tokens.block_size
152153
)
153154

154-
num_tokens = self._length_rng.sample_positive_normal_integer(mean, stddev)
155-
return self._generate_prompt(num_tokens)
155+
num_tokens = self.calculate_num_tokens(mean, stddev)
156+
return self.generate_prompt(num_tokens)
156157

157-
def _generate_prompt(self, num_tokens: int) -> str:
158+
def calculate_num_tokens(
159+
self,
160+
mean: int | None = None,
161+
stddev: int | None = None,
162+
) -> int:
163+
"""Calculate the number of tokens for a prompt based on a normal distribution.
164+
165+
Args:
166+
mean: The mean of the normal distribution.
167+
stddev: The standard deviation of the normal distribution.
168+
"""
169+
170+
return self._length_rng.sample_positive_normal_integer(mean, stddev)
171+
172+
def generate_prompt(self, num_tokens: int) -> str:
158173
"""Generate a prompt containing exactly `num_tokens` number of tokens.
159174
160175
Args:
@@ -175,7 +190,7 @@ def _generate_cached_prompt(
175190
Generate a prompt containing exactly `num_tokens` by reusing previously generated prompts
176191
stored in `_cache`. Each hash index in `hash_ids` corresponds to a block of
177192
`block_size` tokens. If a hash index is found in `_cache`, its stored prompt is reused.
178-
Otherwise, a new prompt is generated using `_generate_prompt()` and stored in `_cache`.
193+
Otherwise, a new prompt is generated using `generate_prompt()` and stored in `_cache`.
179194
180195
Args:
181196
num_tokens: The number of tokens required in the prompt.

src/aiperf/dataset/loader/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MultiTurnDatasetLoader,
3434
)
3535
from aiperf.dataset.loader.random_pool import (
36+
Filename,
3637
RandomPoolDatasetLoader,
3738
)
3839
from aiperf.dataset.loader.sharegpt import (
@@ -48,6 +49,7 @@
4849
"BaseLoader",
4950
"BasePublicDatasetLoader",
5051
"CustomDatasetT",
52+
"Filename",
5153
"MediaConversionMixin",
5254
"MooncakeTrace",
5355
"MooncakeTraceDatasetLoader",

0 commit comments

Comments
 (0)